Variational inference for stochastic differential equations
Dennis Prangle
Newcastle University, UK
October 2018



Page 1

Variational inference for stochastic differential equations

Dennis Prangle

Newcastle University, UK

October 2018

Page 2

Acknowledgements and reference

Joint work with Tom Ryder, Steve McGough, Andy Golightly

Supported by EPSRC cloud computing for big data CDT

and NVIDIA academic GPU grant

Published in ICML 2018

http://proceedings.mlr.press/v80/ryder18a.html

https://github.com/Tom-Ryder/VIforSDEs

Page 3

Overview

Background

Stochastic differential equations
Variational inference

Variational inference for SDEs

Example

Conclusion

Page 4

Stochastic differential equations (SDEs)

Page 5

SDEs

SDE perturbs a differential equation with random noise

Defines a diffusion process

Function(s) which evolve randomly over time

SDE describes its instantaneous behaviour

[Figure: a simulated diffusion path, x against time]

Page 6

SDE applications

Finance/econometrics (Black & Scholes, 1973; Eraker, 2001)

Biology/ecology (Gillespie, 2000; Golightly & Wilkinson, 2011)

Physics (van Kampen, 2007)

Epidemiology (Fuchs, 2013)

Page 7

Univariate SDE definition

dX_t = α(X_t, θ) dt + √(β(X_t, θ)) dW_t,   X_0 = x_0.

Xt is random variable at time t

α is drift term

β is diffusion term

Wt is Brownian motion process

θ is unknown parameters

x0 is initial conditions (can depend on θ)

Formalisation requires stochastic calculus (e.g. Ito)


Page 10

Multivariate SDE definition

dX_t = α(X_t, θ) dt + √(β(X_t, θ)) dW_t,   X_0 = x_0.

Xt is random vector

α is drift vector

β is diffusion matrix

Wt is Brownian motion process vector

θ is unknown parameters

x0 is initial conditions

Page 11

Problem statement

We observe diffusion at several time points

(Usually partial noisy observations)

Primary goal: infer parameters θ

e.g. their posterior distribution

Secondary goal: also infer diffusion states x


Page 13

SDE discretisation

Hard to work with exact SDEs

Common approach is discretisation

Approximate on a discrete grid of times, e.g. evenly spaced 0, ∆τ, 2∆τ, 3∆τ, . . .

Simplest method is Euler-Maruyama

Page 14

Euler-Maruyama discretisation

x_{i+1} = x_i + α(x_i, θ) ∆τ + √(β(x_i, θ) ∆τ) z_{i+1}

xi is state at ith time in grid

∆τ is grid timestep

zi+1 is vector of independent N(0, 1) draws

[Figure: Euler-Maruyama steps of a sample path, x against time]
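The Euler-Maruyama recursion above is only a few lines of code. A minimal sketch in Python; the mean-reverting drift and constant diffusion used here are illustrative assumptions, not the talk's example:

```python
import math
import random

def euler_maruyama(alpha, beta, x0, theta, dtau, n_steps, rng):
    """Simulate one path: x_{i+1} = x_i + alpha*dtau + sqrt(beta*dtau)*z_{i+1}."""
    xs = [x0]
    for _ in range(n_steps):
        x = xs[-1]
        z = rng.gauss(0.0, 1.0)  # independent N(0, 1) draw
        xs.append(x + alpha(x, theta) * dtau
                  + math.sqrt(beta(x, theta) * dtau) * z)
    return xs

# Illustrative SDE: dX = th1*(th2 - X) dt + sqrt(th3) dW (mean reversion)
alpha = lambda x, th: th[0] * (th[1] - x)
beta = lambda x, th: th[2]

path = euler_maruyama(alpha, beta, x0=0.15, theta=(2.0, 0.14, 1e-5),
                      dtau=0.1, n_steps=100, rng=random.Random(1))
```

Finer grids (smaller ∆τ) reduce the discretisation error, at the cost of more steps.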


Page 16

Posterior distribution

Let p(θ) be prior density for parameters

Posterior distribution is

p(θ, x | y) ∝ p(θ) p(x | θ) p(y | x, θ)

(prior × SDE model × observation model)

where

p(x | θ) is a product of normal densities for the state increments (i.e. the x_{i+1} − x_i values), and p(y | x, θ) is a product of normal densities at the observation times

n.b. the right hand side is the unnormalised posterior p(θ, x, y)
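Under the Euler-Maruyama discretisation this unnormalised log posterior is straightforward to evaluate. A minimal sketch in Python; the function names and the assumption that observations fall exactly on grid points (with Gaussian noise of variance obs_var) are mine, not the talk's:

```python
import math

def log_normal(x, mean, var):
    return -0.5 * (math.log(2 * math.pi * var) + (x - mean) ** 2 / var)

def log_unnorm_posterior(theta, x, y_idx, y, alpha, beta, log_prior,
                         dtau, obs_var):
    """log p(theta, x, y) = log p(theta) + log p(x|theta) + log p(y|x, theta)."""
    lp = log_prior(theta)
    # SDE model: normal density for each state increment x_{i+1} - x_i
    for i in range(len(x) - 1):
        lp += log_normal(x[i + 1], x[i] + alpha(x[i], theta) * dtau,
                         beta(x[i], theta) * dtau)
    # observation model: normal density at each observation grid index
    for j, i in enumerate(y_idx):
        lp += log_normal(y[j], x[i], obs_var)
    return lp

# Illustrative check: pure random walk (alpha = 0, beta = 1), flat prior,
# one observation at grid index 1
lp = log_unnorm_posterior(theta=(), x=[0.0, 0.0], y_idx=[1], y=[0.0],
                          alpha=lambda x, th: 0.0, beta=lambda x, th: 1.0,
                          log_prior=lambda th: 0.0, dtau=1.0, obs_var=1.0)
```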


Page 19

Posterior inference

Could use sampling methods

e.g. MCMC (Markov chain Monte Carlo), SMC (sequential Monte Carlo)

But posterior high dimensional and lots of dependency

Very challenging for sampling methods

One approach is to use bridging (next slide)

Page 20

Bridge constructs

Propose x via a bridge construct

Use within Monte Carlo inference

Bridge construct is approx to conditioned diffusion

(usually conditioning just on next observation)

Derived mathematically

Various bridges used in practice

Struggle with highly non-linear paths and large gaps between observation times

Choosing bridges and designing new ones hard work!

We automate this using machine learning


Page 23

Variational inference

Page 24

Variational inference

Goal: inference on posterior p(θ|y)

Given unnormalised version p(θ, y)

Introduce q(θ;φ)

Family of approximate posteriors
Controlled by parameters φ

Idea: find φ giving best approximate posterior

Converts Bayesian inference into optimisation problem

n.b. outputs approximation to posterior


Page 28

Variational inference

VI finds φ minimising KL(q(θ;φ)||p(θ|y))

Equivalent to maximising ELBO (evidence lower bound),

L(φ) = E_{θ∼q(·;φ)} [ log ( p(θ, y) / q(θ; φ) ) ]

(Jordan, Ghahramani, Jaakkola, Saul 1999)

Page 29

Variational inference

Optimum q often finds posterior mode well

But usually overconcentrated! (unless the family of qs allows very good matches)

(source: Yao, Vehtari, Simpson, Gelman 2018)

Page 30

Maximising the ELBO

Several optimisation methods:

Variational calculus
Parametric optimisation (various flavours)

Page 31

Maximising the ELBO

“Reparameterisation trick” (Kingma and Welling 2014; Rezende, Mohamed and Wierstra 2014; Titsias and Lazaro-Gredilla 2014)

Write θ ∼ q(·;φ) as θ = g(ε, φ) where

g an invertible function
ε a random variable with a fixed distribution

Example shortly!

Page 32

Maximising the ELBO

ELBO is

L(φ) = E_ε [ log ( p(θ, y) / q(θ; φ) ) ]

⇒ ∇_φ L(φ) = E_ε [ ∇_φ log ( p(θ, y) / q(θ; φ) ) ]

Unbiased Monte-Carlo gradient estimate

∇_φ L(φ) ≈ ∇_φ log ( p(θ, y) / q(θ; φ) )

where θ = g(ε, φ) for some ε sample

(can average batch of estimates to reduce variance)

Get gradients using automatic differentiation

Optimise L(φ) by stochastic optimisation

Easy to code in Tensorflow, PyTorch etc
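As a minimal self-contained illustration of the whole scheme, here is reparameterised-gradient VI on a hypothetical conjugate toy model (prior θ ∼ N(0, 1), one observation y ∼ N(θ, 1), so the exact posterior is N(y/2, 1/2)); the gradients are written out by hand where autodiff would normally do the work:

```python
import math
import random

# Variational family q = N(mu, sigma^2) with sigma = exp(s),
# reparameterised as theta = g(eps, phi) = mu + sigma * eps.
rng = random.Random(0)
y = 1.0
mu, s = 0.0, 0.0
lr, batch = 0.05, 10

for step in range(4000):
    g_mu = g_s = 0.0
    for _ in range(batch):
        eps = rng.gauss(0.0, 1.0)
        sigma = math.exp(s)
        theta = mu + sigma * eps
        # d/dtheta log p(theta, y) = -theta + (y - theta)
        dlogp = y - 2.0 * theta
        # gradients of log p(theta, y) - log q(theta; phi);
        # the -log q (entropy) term contributes +1 to g_s and 0 to g_mu
        g_mu += dlogp
        g_s += dlogp * sigma * eps + 1.0
    mu += lr * g_mu / batch  # stochastic gradient ascent on the ELBO
    s += lr * g_s / batch

# mu should approach the posterior mean y/2 = 0.5,
# and exp(2s) the posterior variance 1/2
```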


Page 36

Example: mean field approximation

Simplest variational approximation

Assumes θ ∼ N(µ,Σ) for Σ diagonal

Then θ = g(ε, φ) = µ + Σ^{1/2} ε

where:

ε ∼ N(0, I)
φ = (µ, Σ)

Makes strong unrealistic assumptions about posterior!
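Sampling from this approximation via the reparameterisation trick is one line per draw. A minimal sketch (diagonal Σ stored as a vector of variances, an assumption of this snippet):

```python
import math
import random

def sample_mean_field(mu, sigma2, rng):
    """theta = g(eps, phi) = mu + Sigma^{1/2} eps, with Sigma = diag(sigma2)."""
    return [m + math.sqrt(v) * rng.gauss(0.0, 1.0)
            for m, v in zip(mu, sigma2)]

rng = random.Random(42)
draws = [sample_mean_field([1.0, -2.0], [0.25, 4.0], rng)
         for _ in range(20000)]
```

Because θ is a deterministic, differentiable function of (µ, Σ) given ε, gradients can flow through the sample.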


Page 39

Variational inference: summary

Define family of approximate posteriors q(θ;φ)

So that θ = g(ε, φ)

Can use optimisation to minimise KL divergence

Output is approximate posterior

(often good point estimate but overconcentrated)

Page 40

Variational inference for SDEs

Page 41

Variational inference for SDEs

We want posterior p(θ, x |y) for SDE model

Define
q(θ, x; φ) = q(θ; φ_θ) q(x | θ; φ_x)

We use mean-field approx for q(θ;φθ)

Leaves choice of q(x |θ;φx)

Page 42

Variational inference for SDEs

q(x | θ; φ_x) should approximate p(x | θ, y): the “conditioned diffusion”

SDE theory suggests this is itself a diffusion (see e.g. Rogers and Williams 2013)

But with different drift and diffusion to original SDE

Page 43

Variational approximation to diffusion

We define q(x |θ;φx) to be a diffusion

We let drift α and diffusion β depend on:

Parameters θ
Most recent x and t values
Details of next observations

To get flexible parametric functions we use neural network

φx is neural network parameters (weights and biases)

Page 44

Variational approximation to diffusion

[Diagram: at t = 0 the neural network (parameters φ_x) takes x_0, y and θ, and outputs drift α and diffusion β; then z_1 ∼ N(0, I) and x_1 = x_0 + α∆t + √(β∆t) z_1]

Page 45

Variational approximation to diffusion

[Diagram: at t = 1 the neural network takes x_1, y and θ, and outputs new drift α and diffusion β; then z_2 ∼ N(0, I) and x_2 = x_1 + α∆t + √(β∆t) z_2]

Page 46

Variational approximation to diffusion

Typically we take

x_{i+1} = x_i + α∆t + √(β∆t) z_{i+1}

Sometimes want to ensure non-negativity of xs

So we use

x_{i+1} = h( x_i + α∆t + √(β∆t) z_{i+1} )

where h outputs non-negative values, e.g. the softplus function

h(z) = log(1 + e^z)
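The softplus h is cheap to evaluate, but the naive formula overflows for large z; a numerically stable rearrangement (a standard trick, not specific to the talk) is:

```python
import math

def softplus(z):
    """h(z) = log(1 + e^z), rearranged so exp() never overflows."""
    if z > 0:
        return z + math.log1p(math.exp(-z))
    return math.log1p(math.exp(z))
```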


Page 48

Variational approximation to diffusion

Drift and diffusion calculated from neural network

Used to calculate x1

Fed back into the same neural network to get the next drift and diffusion

. . .

Recurrent neural network structure

Powerful but tricky to scale up

Page 49

Algorithm summary

Initialise φθ, φx

Begin loop

Sample θ from q(θ;φθ) (independent normals)

Sample x from q(x | θ; φ_x) (run RNN)

Calculate ELBO gradient

∇L(φ) ≈ ∇_φ log [ p(θ, x, y) / ( q(θ; φ_θ) q(x | θ; φ_x) ) ]

Update φθ, φx by stochastic optimisation

End loop

(n.b. can use larger Monte Carlo batch size)


Page 51

Example

Page 52

Lotka-Volterra example

Classic population dynamics model

Two populations: prey and predators

Three processes

Prey growth
Predator growth (by consuming prey)
Predator death

Many variations exist

We use SDE from Golightly and Wilkinson (2011)

Page 53

Lotka-Volterra example

Prey population at time t is Ut

Predator population at time t is Vt

Drift

α(X_t, θ) = ( θ_1 U_t − θ_2 U_t V_t ,  θ_2 U_t V_t − θ_3 V_t )ᵀ

Diffusion

β(X_t, θ) = ( θ_1 U_t + θ_2 U_t V_t   −θ_2 U_t V_t ;  −θ_2 U_t V_t   θ_3 V_t + θ_2 U_t V_t )

Parameters:

θ_1 controls prey growth
θ_2 controls predator growth by consuming prey
θ_3 controls predator death
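These drift and diffusion functions can be coded directly and simulated with Euler-Maruyama; since β is a 2×2 matrix, its Cholesky factor multiplies the noise vector. A sketch (the parameter values and initial populations are illustrative; in practice a softplus or similar transform would also guard against negative populations):

```python
import math
import random

def drift(x, th):
    u, v = x
    return [th[0] * u - th[1] * u * v,
            th[1] * u * v - th[2] * v]

def diffusion(x, th):
    u, v = x
    b = th[1] * u * v
    return [[th[0] * u + b, -b],
            [-b, th[2] * v + b]]

def chol2(m):
    """Lower-triangular Cholesky factor of a 2x2 SPD matrix."""
    l11 = math.sqrt(m[0][0])
    l21 = m[1][0] / l11
    l22 = math.sqrt(m[1][1] - l21 * l21)
    return [[l11, 0.0], [l21, l22]]

def em_step(x, th, dtau, rng):
    a, L = drift(x, th), chol2(diffusion(x, th))
    z = [rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)]
    sq = math.sqrt(dtau)
    return [x[0] + a[0] * dtau + sq * L[0][0] * z[0],
            x[1] + a[1] * dtau + sq * (L[1][0] * z[0] + L[1][1] * z[1])]

x = [100.0, 100.0]  # illustrative initial prey/predator populations
rng = random.Random(3)
for _ in range(10):
    x = em_step(x, (0.5, 0.0025, 0.3), 0.1, rng)
```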

Page 54

Lotka-Volterra example

[Figure: prey and predator populations against time; crosses mark observations]

Page 55

Settings - model

IID priors: log θ_i ∼ N(0, 3²) for i = 1, 2, 3

Discretisation time step ∆τ = 0.1

Observation variance Σ = I, small relative to typical population sizes

Challenging scenario:

Non-linear diffusion paths
Small observation variance
Long gaps between observations

Page 56

Settings - variational inference

Batch size 50 for gradient estimate

4 layer neural network (20 ReLU units / layer)

Softplus transformation to avoid proposing negative population levels

Various methods to avoid numerical problems in training

Page 57

Results


Page 94

Parameter inference results

[Figure: marginal posterior densities for θ_1, θ_2 and θ_3]

Black: true parameter values

Blue: variational output

Green: importance sampling (shows over-concentration)

Computing time: ≈ 2 hours on a desktop PC

Page 95

Results

Can also do inference under a misspecified model:

[Figure: inferred prey and predator populations against time under the misspecified model]

Page 96

Conclusion

Page 97

Summary

Variational approach to approximate Bayesian inference for SDEs

Little tuning required

Results in a few hours on a desktop PC

We observe good estimation of posterior mode

Page 98

More in paper

Real data epidemic example

Diffusions with unobserved components

Page 99

Current/future work

Normalising flows instead of RNNs

Big data - long or wide

Other models: state space models, HMMs, MJPs . . . (discrete variables challenging!)

Model comparison/improvement

Real applications - get in touch with any suggestions!

Page 100

Acknowledgements and reference

Joint work with Tom Ryder, Steve McGough, Andy Golightly

Supported by EPSRC cloud computing for big data CDT

and NVIDIA academic GPU grant

http://proceedings.mlr.press/v80/ryder18a.html

https://github.com/Tom-Ryder/VIforSDEs

Presented at ICML 2018