Learning outcomes:
- Be able to fit some basic zero-inflation and hurdle models
- Be able to understand and fit some multinomial modelling examples
Zero-inflation and hurdle models
- Let's introduce some new data. This is data from an experiment on whiteflies:

wf = read.csv('../data/whitefly.csv')
head(wf)
##   imm week block trt  n live plantid
## 1  15    1     3   5 12   11       1
## 2  16    2     3   5  8    6       1
## 3  28    3     3   5 10   10       1
## 4  17    4     3   5 10    8       1
## 5   9    5     3   5 10   10       1
## 6  28    6     3   5 10   10       1

The response variable here is the count imm of immature whiteflies, and the explanatory variables are block (plant number), week, and treatment trt.
Look at those zeros!

barplot(table(wf$imm),
        main = 'Number of immature whiteflies')

[Figure: bar plot titled 'Number of immature whiteflies'; the x-axis shows counts from 0 to 75 and the y-axis shows frequencies from 0 to 300.]
A first model
- These are count data, so a Poisson distribution is a good start
- Let's consider a basic Poisson model for observations $Y_i$, $i = 1, \ldots, N$:

$$Y_i \sim \text{Po}(\lambda_i)$$

$$\log(\lambda_i) = \beta_{\text{trt}_i}$$

- We'll only consider the treatment effect, but we could run much more complicated models with e.g. other covariates and interactions
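As a quick check of how the log link works, here is a minimal R sketch on simulated data (not the whitefly data). The treatment effects `beta_true`, the sample size, and the use of `glm` as a non-Bayesian stand-in for the JAGS fit are all assumptions made purely for illustration.

```r
# Simulate counts from Y_i ~ Po(lambda_i), log(lambda_i) = beta[trt_i]
set.seed(123)
N_trt <- 6
n_per_trt <- 1000                          # hypothetical sample size
beta_true <- c(0.5, 1, 1.5, 2, 2.5, 3)     # hypothetical treatment effects
trt <- rep(1:N_trt, each = n_per_trt)
lambda <- exp(beta_true[trt])
y <- rpois(length(trt), lambda)

# Maximum likelihood fit with one coefficient per treatment (no intercept)
fit <- glm(y ~ factor(trt) - 1, family = poisson)
beta_hat <- unname(coef(fit))

# With this parameterisation each estimate is the log of a group mean
group_log_mean <- unname(log(tapply(y, trt, mean)))
print(round(cbind(beta_true, beta_hat, group_log_mean), 3))

# Under a plain Poisson the probability of a zero is exp(-lambda)
p_zero <- exp(-exp(beta_true))
print(round(p_zero, 4))
```

The last line hints at why many zeros are awkward for this model: once the Poisson mean is moderately large (say exp(2), roughly 7.4), the probability of a zero falls below 0.001.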
Fitting the model in JAGS

model_code = 'model{
  # Likelihood
  for (i in 1:N) {
    y[i] ~ dpois(lambda[i])
    # Posterior predictive draw for observation i
    y_pp[i] ~ dpois(lambda[i])
    log(lambda[i]) <- beta_trt[trt[i]]
  }
  # Priors
  for (j in 1:N_trt) {
    beta_trt[j] ~ dnorm(beta_mean, beta_sd^-2)
  }
  beta_mean ~ dnorm(0, 10^-2)
  beta_sd ~ dt(0, 5^-2, 1)T(0,)
}'
Running the model
library(R2jags) # provides jags()

jags_run = jags(data = list(N = nrow(wf),
                            N_trt = length(unique(wf$trt)),
                            y = wf$imm,
                            trt = wf$trt),
                parameters.to.save = c('beta_trt', 'y_pp',
                                       'beta_mean', 'beta_sd'),
                model.file = textConnection(model_code))
Results

plot(jags_run)

[Figure: summary plot from plot(jags_run). Left panel: 80% interval for each chain and R-hat for beta_mean, beta_sd, beta_trt[1] to beta_trt[6] and deviance. Right panel: medians and 80% intervals for beta_mean, beta_sd, beta_trt, deviance and y_pp. Bugs model at "4", fit using jags, 3 chains, each with 2000 iterations (first 1000 discarded).]
Some clear treatment effects - treatment 5 in particular
Did the model actually fit well?

y_pp = jags_run$BUGSoutput$mean$y_pp
par(mfrow=c(1,2))
hist(wf$imm, breaks = seq(0, max(wf$imm)),
     main = 'Data vs posterior predictive fit')
hist(y_pp, breaks = seq(0, max(wf$imm)), add = TRUE, col = 'gray')
plot(wf$imm, y_pp); abline(a = 0, b = 1)

[Figure: left panel, histogram 'Data vs posterior predictive fit' of wf$imm against Frequency, with the posterior predictive means overlaid in gray; right panel, scatter plot of y_pp against wf$imm with the line y = x.]
What about the zeros?
- One way of broadening the distribution is through over-dispersion, which we have already met:

$$\log(\lambda_i) \sim N(\beta_{\text{trt}_i}, \sigma^2)$$

- However, this doesn't really solve the problem of excess zeros
- Instead there is a specific class of models called zero-inflation models, which use a specific probability distribution. The zero-inflated Poisson (ZIP) with ZI parameter $q_0$ is written as:

$$p(y \mid \lambda) = \begin{cases} q_0 + (1 - q_0) \times \text{Poisson}(0, \lambda) & \text{if } y = 0 \\ (1 - q_0) \times \text{Poisson}(y, \lambda) & \text{if } y \neq 0 \end{cases}$$
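The ZIP probability function above can be written directly in R. This is a minimal sketch: the function name `dzip` and the values of `lambda` and `q0` are hypothetical choices for illustration, not part of the course code.

```r
# Zero-inflated Poisson probability mass function
# q0 is the probability of a structural (extra) zero
dzip <- function(y, lambda, q0) {
  ifelse(y == 0,
         q0 + (1 - q0) * dpois(0, lambda),
         (1 - q0) * dpois(y, lambda))
}

lambda <- 5     # hypothetical Poisson mean
q0 <- 0.4       # hypothetical zero-inflation probability

# The probabilities sum to 1 (the tail beyond 200 is negligible here)
total <- sum(dzip(0:200, lambda, q0))
print(total)

# Probability of a zero: plain Poisson vs ZIP
print(c(poisson = dpois(0, lambda), zip = dzip(0, lambda, q0)))
```

With these values the plain Poisson puts less than 1% of its mass on zero, while the ZIP puts a little over 40% there.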
Fitting models with custom probability distributions
- The zero-inflated Poisson distribution is not included in Stan or JAGS by default. We have to create it
- It's possible to create new probability distributions in Stan
- It's a little bit fiddly to do so in JAGS; we have to use some tricks
- We will use JAGS to create a mixture of Poisson distributions: a Poisson(0) distribution for the zeros, and a Poisson($\lambda$) distribution for the rest
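The mixture idea can be checked by simulation in plain R before moving to JAGS. The sketch below uses a latent 0/1 indicator to switch between Poisson(0) and Poisson($\lambda$); the values of `lambda`, `q0` and `n` are hypothetical.

```r
set.seed(42)
n <- 100000
lambda <- 5      # hypothetical Poisson mean
q0 <- 0.4        # hypothetical probability of a structural zero

# Latent indicator: z = 1 means the count comes from Poisson(lambda),
# z = 0 means it comes from Poisson(0), i.e. it is always zero
z <- rbinom(n, size = 1, prob = 1 - q0)
y <- rpois(n, lambda * z)

# Compare the simulated proportion of zeros with the ZIP formula
prop_zero_sim <- mean(y == 0)
prop_zero_theory <- q0 + (1 - q0) * dpois(0, lambda)
print(c(simulated = prop_zero_sim, theory = prop_zero_theory))
```

Multiplying the rate by the indicator is the same device a JAGS model can use, since JAGS has no built-in ZIP distribution.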
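Fitting the ZIP in JAGS

model_code = 'model{
  # Likelihood
  for (i in 1:N) {
    # The + 0.0001 keeps the Poisson rate strictly positive when z[i] = 0
    y[i] ~ dpois(lambda[i] * z[i] + 0.0001)
    y_pp[i] ~ dpois(lambda[i] * z[i] + 0.0001)
    log(lambda[i]) <- beta_trt[trt[i]]
    # z[i] = 1 selects the Poisson(lambda) component, so q_0 in this code
    # is the probability of NOT being a structural zero
    z[i] ~ dbinom(q_0, 1)
  }
  # Priors
  for (j in 1:N_trt) {
    beta_trt[j] ~ dnorm(beta_mean, beta_sd^-2)
  }
  beta_mean ~ dnorm(0, 10^-2)
  beta_sd ~ dt(0, 5^-2, 1)T(0,)
  q_0 ~ dunif(0, 1)
}'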
Running the model
jags_run = jags(data = list(N = nrow(wf),
                            N_trt = length(unique(wf$trt)),
                            y = wf$imm,
                            trt = wf$trt),
                parameters.to.save = c('beta_trt', 'q_0', 'y_pp',
                                       'beta_mean', 'beta_sd'),
                model.file = textConnection(model_code))
Results
plot(jags_run)
[Figure: summary plot from plot(jags_run). Left panel: 80% interval for each chain and R-hat for beta_mean, beta_sd, beta_trt[1] to beta_trt[6], deviance and q_0. Right panel: medians and 80% intervals for beta_mean, beta_sd, beta_trt, deviance, q_0 (roughly 0.48 to 0.56) and y_pp. Bugs model at "5", fit using jags, 3 chains, each with 2000 iterations (first 1000 discarded).]
Did it work any better? - code
y_pp = jags_run$BUGSoutput$mean$y_pp
par(mfrow=c(1,2))
hist(wf$imm, breaks = seq(0, max(wf$imm)),
     main = 'Data vs posterior predictive fit')
hist(y_pp, breaks = seq(0, max(wf$imm)), add = TRUE, col = 'gray')
plot(wf$imm, y_pp); abline(a = 0, b = 1)

[Figure: left panel, histogram 'Data vs posterior predictive fit' of wf$imm against Frequency, with the posterior predictive means overlaid in gray; right panel, scatter plot of y_pp against wf$imm with the line y = x.]
Some more notes on Zero-inflated Poisson
- This model seems to predict the number of zeros pretty well. It would also be interesting to try having a different probability of zeros ($q_0$) for different treatments
- It might be that the other covariates explain some of the zero behaviour
- We could further add in both zero-inflation and over-dispersion
An alternative: hurdle models
- ZI models work by having a parameter (here $q_0$) which is the probability of getting an extra zero, and so the probability of getting a Poisson value (which could also be a zero) is 1 minus this value
- An alternative (which is slightly more complicated) is a hurdle model, where $q_0$ is the probability of a zero and this is the only way of getting a zero. With probability $(1 - q_0)$ we end up with a special Poisson random variable which has to take values 1 or more
- In some ways this is richer than a ZI model, since zeros can be deflated or inflated
- This is a bit fiddlier to fit in JAGS
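The hurdle description above corresponds to the following probability function, sketched here in R. The function name `dhurdle` and the parameter values are hypothetical and only serve to illustrate how zeros can be deflated as well as inflated.

```r
# Hurdle-Poisson probability mass function
# q0 is the probability of a zero; positive counts follow a
# zero-truncated Poisson(lambda)
dhurdle <- function(y, lambda, q0) {
  ifelse(y == 0,
         q0,
         (1 - q0) * dpois(y, lambda) / (1 - dpois(0, lambda)))
}

lambda <- 1     # hypothetical mean of the untruncated Poisson
total <- sum(dhurdle(0:200, lambda, q0 = 0.1))
print(total)

# A plain Poisson(1) gives P(Y = 0) = exp(-1), about 0.37.
# The hurdle model can go below this (deflation) or above it (inflation)
print(c(poisson = dpois(0, lambda),
        deflated = dhurdle(0, lambda, q0 = 0.1),
        inflated = dhurdle(0, lambda, q0 = 0.6)))
```

Dividing by `1 - dpois(0, lambda)` renormalises the Poisson over the values 1, 2, ..., which is what makes $q_0$ the only source of zeros.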
A hurdle-Poisson model in JAGS

model_code = 'model{
  # Likelihood
  # Positive counts only: Poisson truncated to values of 1 or more
  for (i in 1:N) {
    y[i] ~ dpois(lambda[i])T(1,)
    log(lambda[i]) <- beta_trt[trt[i]]
  }
  # Zero indicator for every observation
  for(i in 1:N_0) {
    y_0[i] ~ dbin(q_0, 1)
  }
  # Priors
  for (j in 1:N_trt) {
    beta_trt[j] ~ dnorm(beta_mean, beta_sd^-2)
  }
  beta_mean ~ dnorm(0, 10^-2)
  beta_sd ~ dt(0, 5^-2, 1)T(0,)
  q_0 ~ dunif(0, 1)
}'
Running the model
jags_run = jags(data = list(N = nrow(wf[wf$imm > 0,]),
                            N_trt = length(unique(wf$trt)),
                            y = wf$imm[wf$imm > 0],
                            y_0 = as.integer(wf$imm == 0),
                            N_0 = nrow(wf),
                            trt = wf$trt[wf$imm > 0]),
                parameters.to.save = c('beta_trt', 'q_0',
                                       'beta_mean', 'beta_sd'),
                model.file = textConnection(model_code))
Results

plot(jags_run)

[Figure: summary plot from plot(jags_run). Left panel: 80% interval for each chain and R-hat for beta_mean, beta_sd, beta_trt[1] to beta_trt[6] and deviance. Right panel: medians and 80% intervals for beta_mean, beta_sd, beta_trt, deviance and q_0 (roughly 0.46 to 0.54). Bugs model at "4", fit using jags, 3 chains, each with 2000 iterations (first 1000 discarded).]
Some final notes on ZI models
- To complete the hurdle-Poisson fit we would need to simulate from a truncated Poisson model. This starts to get very fiddly though: see the jags_examples repository for worked examples
- We can extend these models further by using a better count distribution such as the negative binomial, which has an extra over-dispersion parameter
- We can also add covariates into the zero-inflation component, though it is not always clear whether this is desirable
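Outside JAGS, simulating from a zero-truncated Poisson is straightforward in R. The sketch below is one possible approach (inverse-CDF sampling), not the method used in the jags_examples repository; the helper names `rtpois` and `rhurdle` and all parameter values are hypothetical.

```r
# Draw from a Poisson(lambda) truncated to values 1, 2, ...
# using the inverse-CDF method on the interval (P(Y = 0), 1)
rtpois <- function(n, lambda) {
  u <- runif(n, min = dpois(0, lambda), max = 1)
  qpois(u, lambda)
}

# Draw from a hurdle-Poisson: zero with probability q0,
# otherwise a zero-truncated Poisson value
rhurdle <- function(n, lambda, q0) {
  is_zero <- rbinom(n, size = 1, prob = q0)
  ifelse(is_zero == 1, 0, rtpois(n, lambda))
}

set.seed(1)
lambda <- 2    # hypothetical value
q0 <- 0.5      # hypothetical value
x <- rtpois(50000, lambda)
y <- rhurdle(50000, lambda, q0)

print(min(x))                                   # truncated draws start at 1
print(c(simulated = mean(x),
        theory = lambda / (1 - exp(-lambda))))  # truncated Poisson mean
print(mean(y == 0))                             # should be close to q0
```

In a full posterior predictive check, `lambda` and `q0` would be replaced by posterior draws of exp(beta_trt) and q_0 from the fitted model.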
The multinomial distribution
- Multinomial data can be thought of as multivariate discrete data
- It's usually used in two different scenarios:
  1. For classification, when you have an observation falling into a single one of $K$ possible categories
  2. For multinomial regression, where you have a set of counts which sum to a known value $N$
- We will just consider the multinomial regression case, whereby we have observations $y_i = [y_{i1}, \ldots, y_{iK}]$ where the sum $\sum_{k=1}^{K} y_{ik} = N_i$ is fixed (this total is written $S_i$ in the model that follows)
- The classification version is a simplification of the regression version
Some new data! - pollen
pollen = read.csv('../data/pollen.csv')
head(pollen)
##   GDD5 MTCO Abies Alnus Betula Picea Pinus.D Quercus.D Gramineae
## 1 1874 -7.9     0    50    158     7     721        22         0
## 2 1623 -5.5     0    38     28   302     537        19         0
## 3 1475 -4.7     0   276    183   110     136         0         0
## 4 1360 -8.8     0   111    354   141     364         0         0
## 5 1295 -6.9     0    91     50   151     708         0         0
## 6 1539 -7.8     0    51    194    82     673         0         0

These data are pollen counts of 7 varieties of pollen from modern samples, with two covariates
Some plots
- The two covariates represent the length of the growing season (GDD5) and harshness of the winter (MTCO)
- The task is to find which climate regimes each pollen variety favours

[Figure: estimated proportion (y-axis, 0 to 1) against length of growing season (x-axis, 0 to 7000).]
A multinomial model
- The multinomial distribution is often written as:

$$[y_{i1}, \ldots, y_{iK}] \sim \text{Mult}(S_i, \{p_{i1}, \ldots, p_{iK}\})$$

or, for short:

$$y_i \sim \text{Mult}(S_i, p_i)$$

- The key parameters here are the probability vectors $p_i$. It's these we want to use a link function on to include the covariates
- We need to be careful as each must sum to one: $\sum_{k=1}^{K} p_{ik} = 1$. Any link function must satisfy this constraint
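To make the notation concrete, base R can simulate and evaluate the multinomial directly. In this small sketch the probability vector `p` and the total `S` are hypothetical values.

```r
set.seed(7)
p <- c(0.5, 0.3, 0.2)   # hypothetical probability vector, sums to 1
S <- 100                # hypothetical total count

# One multinomial draw: K counts that always sum to S
y <- rmultinom(1, size = S, prob = p)[, 1]
print(y)
print(sum(y))

# Probability of observing exactly this vector of counts
print(dmultinom(y, size = S, prob = p))

# The classification case is the same distribution with S = 1:
# a vector of zeros with a single 1 marking the chosen category
y_class <- rmultinom(1, size = 1, prob = p)[, 1]
print(y_class)
```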
Prior distributions on probability vectors
- When $K = 2$ we're back to the binomial-logit model we met on the first day, and we can use the logit link function
- When $K > 2$ a common function to use is the soft-max function:

$$p_{ik} = \frac{\exp(z_{ik})}{\sum_{j=1}^{K} \exp(z_{ij})}$$

- This is a generalisation of the logit function
- The next layer of our model sets, e.g.:

$$z_{ik} = \beta_{0k} + \beta_{1k} \text{GDD5}_i + \beta_{2k} \text{MTCO}_i + \ldots$$
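The soft-max function is easy to write and check in R. This is a minimal sketch: the function name `softmax` and the example values of `z` are hypothetical, and subtracting `max(z)` is a standard numerical-stability device rather than something taken from the slides.

```r
# Soft-max: maps any real vector z to probabilities that sum to 1.
# Subtracting max(z) leaves the result unchanged but avoids overflow
softmax <- function(z) {
  e <- exp(z - max(z))
  e / sum(e)
}

z <- c(1.2, -0.3, 0.5)   # hypothetical linear predictors for K = 3
p <- softmax(z)
print(p)
print(sum(p))

# With K = 2 and the first element fixed at 0, the soft-max
# reduces to the inverse logit (plogis in R)
print(c(softmax(c(0, 0.8))[2], plogis(0.8)))
```

The two numbers on the last line agree, which is the sense in which the soft-max generalises the logit link.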
JAGS code

model_code = 'model{
  # Likelihood
  for (i in 1:N) { # Observation loop
    y[i,] ~ dmulti(p[i,], S[i])
    for(j in 1:M) { # Category loop
      exp_z[i,j] <- exp(z[i,j])
      p[i,j] <- exp_z[i,j]/sum(exp_z[i,])
      z[i,j] <- beta[j,]%*%x[i,]
    }
  }
  # Prior
  for(j in 1:M) {
    for(k in 1:K) {
      beta[j,k] ~ dnorm(0, 0.1^-2)
    }
  }
}'
Let’s fit it (first 500 obs only)
model_data = list(N = nrow(pollen[1:500,]),
                  y = pollen[1:500,3:9],
                  x = cbind(1, scale(cbind(pollen[1:500,1:2],
                                           pollen[1:500,1:2]^2))),
                  S = pollen[1:500,10],
                  K = 5, # Number of covars
                  M = 7) # Number of categories

# Run the model
model_run = jags(data = model_data,
                 parameters.to.save = c("p"),
                 model.file = textConnection(model_code))
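The design matrix `x` above has K = 5 columns: an intercept plus the standardised linear and squared covariates. The sketch below reproduces that construction on made-up covariate values (the data frame `covars` is hypothetical and only stands in for the first two columns of the pollen data).

```r
set.seed(11)
# Hypothetical covariate values standing in for GDD5 and MTCO
covars <- data.frame(GDD5 = runif(20, 500, 5000),
                     MTCO = runif(20, -20, 10))

# Intercept column plus standardised linear and squared terms
x <- cbind(1, scale(cbind(covars, covars^2)))
print(dim(x))                        # rows, then K = 5 columns
print(round(colMeans(x[, -1]), 10))  # standardised columns have mean 0
print(apply(x[, -1], 2, sd))         # and standard deviation 1
```

Standardising puts all covariates on a comparable scale, which matters here because the same tight prior is used for every element of beta.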
Results 1

plot(model_run)

[Figure: summary plot from plot(model_run). Left panel: 80% interval for each chain and R-hat for deviance. Right panel: medians and 80% intervals for deviance and p. Bugs model at "6", fit using jags, 3 chains, each with 2000 iterations (first 1000 discarded).]
Results 2
[Figure: estimated proportion of Betula (y-axis, 0 to 0.6) against true proportion of Betula (x-axis, 0 to 1).]
Notes about this model
- This model is not going to fit very well, since it is unlikely that a linear relationship between the covariates and the pollen counts will match the data
- It might be better to use e.g. a spline model (covered in the next class)
- Similarly we might need some complex interactions between the covariates as they are strongly linked
- We have constrained the parameters here so that the slopes and intercepts borrow strength across species. Does this make sense? What else could we do?
Some final notes about multinomial models
- These models can be a pain to deal with as there are tricky constraints on the $\beta$ parameters to make them all sum to 1. Instead it's often easier to just put a tight prior distribution on them, e.g. $\beta \sim N(0, 0.1^2)$ as in the JAGS code above
- The soft-max function is one choice but there are lots of others (logistic ratios, the Dirichlet distribution, ...)
- The classification version of this model just has binary $y_i$ (with just a single 1 in it, i.e. $S_i = 1$); most packages (including JAGS and Stan) have a special distribution (e.g. dcat in JAGS) for this situation
Summary
- We have fitted some zero-inflated and hurdle Poisson models in JAGS
- We have seen a simple multinomial logistic regression