Collapsed Gibbs Sampling for Dirichlet Process Gaussian Mixture Models

I really enjoyed the pedagogy of Edwin Chen’s introduction to infinite mixture models, but I was a little disappointed that it stops short of presenting the details of the Dirichlet process Gaussian mixture model (DPGMM), as he uses sklearn’s variational Bayes DPGMM implementation.

For this reason, I will try to give here sufficient information to implement a DPGMM with collapsed Gibbs sampling. This is not an in-depth evaluation of which conjugate priors to use, nor an analysis of the parameters and hyper-parameters (which should have their own priors! ;)).

Prerequisites

On Dirichlet processes, Chinese Restaurant processes, and Indian Buffet processes, there is the excellent blog post by Edwin Chen. Another excellent introduction to Dirichlet processes is provided by Frigyik, Kapila and Gupta.

If you lack some knowledge about clustering or density estimation (unsupervised learning), you can read Chapters 20 (p. 284) to (at least) 22 of MacKay’s ITILA, which is available as a free ebook, or Chapter 9 of Bishop’s PRML. As a quicker refresher, the Wikipedia article on mixture models and the sklearn documentation on GMMs are more efficient.

DPGMM: the model

Let’s say we have $N$ observations and $K$ clusters; $i \in [1\dots N]$ indexes the observations, while $k \in [1\dots K]$ indexes the clusters. We write $z_i$ for the cluster assignment of observation $x_i$, and $\theta_k$ for the parameters of mixture component $k$.

So, the generative story of a DPGMM is as follows:

$\pi \sim Stick(\alpha)$ (mixing rates)
$z_i \sim \pi$ (cluster assignments)
$\theta_k \sim H(\lambda)$ (parameters)
$x_i \sim F(\theta_{z_i})$ (values)
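To make this story concrete, here is a minimal sketch that samples from it, using truncated stick-breaking and a simplified $H$ (a Gaussian prior on the means with identity covariance, instead of the normal-inverse-Wishart we will use below); all names here are mine, not from the post’s code:

```python
import numpy as np

def sample_dpgmm(N, alpha, d=2, trunc=50, seed=0):
    """Draw N points from a (truncated) DPGMM generative story."""
    rng = np.random.default_rng(seed)
    # pi ~ Stick(alpha): stick-breaking weights, truncated at `trunc` components
    betas = rng.beta(1.0, alpha, size=trunc)
    pi = betas * np.concatenate(([1.0], np.cumprod(1.0 - betas)[:-1]))
    pi /= pi.sum()                      # renormalize the truncated stick
    # theta_k ~ H(lambda): here simply Gaussian means, fixed identity covariance
    means = rng.normal(0.0, 5.0, size=(trunc, d))
    # z_i ~ pi, then x_i ~ F(theta_{z_i}) = N(mean_{z_i}, I)
    z = rng.choice(trunc, size=N, p=pi)
    x = means[z] + rng.normal(size=(N, d))
    return x, z
```

Truncating the stick at a fixed number of components is only a sampling convenience; the model itself is unbounded.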

Fitting the data

Notation: $z_{-i}$ denotes all cluster assignments except $z_i$, $x_{k,-i}$ denotes the observations currently assigned to cluster $k$ excluding $x_i$, and $N_{k,-i}$ is their number.

Let’s decompose the probability that the observation $i$ belongs to cluster $k$ into its two independent factors:

$$P(z_i = k \mid z_{-i}, x, \alpha, \lambda) \propto P(z_i = k \mid z_{-i}, \alpha)\, P(x_i \mid x_{k,-i}, \lambda)$$

The first factor is the Chinese Restaurant process prior on assignments:

$$P(z_i = k \mid z_{-i}, \alpha) = \frac{N_{k,-i}}{\alpha + N - 1}$$

Then:

$$P(x_i \mid x_{k,-i}, \lambda) = \int P(x_i \mid \theta_k)\, P(\theta_k \mid x_{k,-i}, \lambda)\, d\theta_k = \frac{P(x_{k,+i} \mid \lambda)}{P(x_{k,-i} \mid \lambda)}$$

where the numerator is the marginal likelihood of all the data assigned to cluster $k$, including $i$.

If $z_i = k^*$ (a new cluster), then:

$$P(z_i = k^* \mid z_{-i}, \alpha) = \frac{\alpha}{\alpha + N - 1} \qquad \text{and} \qquad P(x_i \mid \lambda) = \int P(x_i \mid \theta)\, H(\theta \mid \lambda)\, d\theta$$
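For concreteness, here is a tiny sketch of this assignment prior (a hypothetical helper, not from the post’s code), with a worked example:

```python
import numpy as np

def crp_weights(counts, alpha):
    """Chinese Restaurant process prior over existing clusters plus a new one.

    counts[k] = N_{k,-i}, the size of cluster k once x_i has been removed.
    """
    N = counts.sum() + 1               # total number of observations, x_i included
    return np.append(counts, alpha) / (alpha + N - 1)  # already sums to 1

# e.g. three clusters of sizes 4, 3, 2 (x_i removed), alpha = 1:
# crp_weights(np.array([4, 3, 2]), 1.0) -> [0.4, 0.3, 0.2, 0.1]
```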

Conjugate priors

Now we should choose $H$ to be conjugate to $F$, so that the parameter posterior is easy to compute. As we want $F$ to be a multivariate normal, we can look at Wikipedia’s page on conjugate priors, under multivariate normal with unknown $\mu$ and $\Sigma$, to see that $H$ should be a normal-inverse-Wishart distribution with prior parameters:

• $\mu_0$ initial mean guess [In my code further, I set it to the mean of the whole dataset.]
• $\kappa_0$ mean fraction (smoothing parameter) [A common value is 1. I set it to 0.]
• $\nu_0$ degrees of freedom [I set it to the number of dimensions.]
• $\Psi_0$ pairwise deviation product (matrix) [I set it to $10 \times I_d$ ($I_d$ is the $d\times d$ identity matrix). The identity matrix makes this prior Gaussian circular; the $10$ factor should be dependent on the dataset, for instance on the mean distance between points. See the setup sketch below.]
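Here is a minimal initialization sketch following these choices (the helper name is mine; note that I use $\kappa_0 = 1$ rather than $0$ in the code, since $\kappa_0 > 0$ is needed for the new-cluster predictive further below):

```python
import numpy as np

def niw_prior(X, psi_scale=10.0):
    """Normal-inverse-Wishart hyperparameters following the choices above."""
    d = X.shape[1]
    mu_0 = X.mean(axis=0)          # initial mean guess: mean of the whole dataset
    kappa_0 = 1.0                  # common value; the post sets it to 0 (see below)
    nu_0 = float(d)                # degrees of freedom: number of dimensions
    psi_0 = psi_scale * np.eye(d)  # circular prior; the factor 10 is dataset-dependent
    return mu_0, kappa_0, nu_0, psi_0
```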

This gives us MAP estimates of the parameters, for one of the clusters with $n$ assigned points:

$$\mu_n = \frac{\kappa_0 \mu_0 + n \tilde{x}}{\kappa_0 + n}, \qquad \kappa_n = \kappa_0 + n, \qquad \nu_n = \nu_0 + n$$

$$\Psi_n = \Psi_0 + C + \frac{\kappa_0\, n}{\kappa_0 + n} (\tilde{x} - \mu_0)(\tilde{x} - \mu_0)^T$$

with $\tilde{x}$ the sample mean and $C=\sum_{i=1}^n (x_i-\tilde{x})(x_i-\tilde{x})^T$ the scatter matrix of the cluster.

Set $\kappa_{0} = 0$ for the prior to have no effect on the posterior mean. This reduces to the MLE estimates if, in addition, $\nu_0 = 0$ and $\Psi_0 = 0$: then $\mu_n = \tilde{x}$ and $\Psi_n / \nu_n = C / n$, the empirical mean and covariance.
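In code, the posterior update could look like this (a sketch under the same notation; `niw_posterior` is my name):

```python
import numpy as np

def niw_posterior(X_k, mu_0, kappa_0, nu_0, psi_0):
    """Posterior NIW parameters given the points X_k assigned to one cluster."""
    n = X_k.shape[0]
    xbar = X_k.mean(axis=0)
    diff = X_k - xbar
    C = diff.T @ diff                              # scatter matrix
    kappa_n = kappa_0 + n
    nu_n = nu_0 + n
    mu_n = (kappa_0 * mu_0 + n * xbar) / kappa_n
    dm = (xbar - mu_0).reshape(-1, 1)
    psi_n = psi_0 + C + (kappa_0 * n / kappa_n) * (dm @ dm.T)
    return mu_n, kappa_n, nu_n, psi_n
```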

So now we can compute the posterior predictive for cluster $k$ evaluated at $x_i$, which is a multivariate Student-t distribution:

$$P(x_i \mid x_{k,-i}, \lambda) = t_{\nu_n - d + 1}\!\left(x_i \,\middle|\, \mu_n,\; \frac{\Psi_n (\kappa_n + 1)}{\kappa_n (\nu_n - d + 1)}\right)$$

where $\mu_n$, $\kappa_n$, $\nu_n$, $\Psi_n$ are computed from the points of cluster $k$ excluding $x_i$.
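A standard way to evaluate this multivariate Student-t density in log space (my helper, plain NumPy and SciPy):

```python
import numpy as np
from scipy.special import gammaln

def multivariate_t_logpdf(x, mu, Sigma, df):
    """Log density of a multivariate Student-t with scale matrix Sigma."""
    d = len(mu)
    L = np.linalg.cholesky(Sigma)
    sol = np.linalg.solve(L, x - mu)
    maha = sol @ sol                              # Mahalanobis distance squared
    logdet = 2.0 * np.sum(np.log(np.diag(L)))
    return (gammaln((df + d) / 2.0) - gammaln(df / 2.0)
            - 0.5 * (d * np.log(df * np.pi) + logdet)
            - 0.5 * (df + d) * np.log1p(maha / df))
```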

Collapsed Gibbs sampling

Here is the pseudo-code of collapsed Gibbs sampling, adapted from Algorithm 3 of Neal’s seminal paper:

```
while (not converged on mus and sigmas):
    for each i = 1 : N in random order do:
        remove x[i]'s sufficient statistics from old cluster z[i]
        if any cluster is empty, remove it and decrease K
        for each k = 1 : K do:
            compute P_k(x[i]) = P(x[i] | x[-i] = k)
            N[k,-i] = dim(x[-i] = k)
            compute P(z[i] = k | z[-i], Data) = N[k,-i] / (alpha + N - 1)
        compute P*(x[i]) = P(x[i] | lambda)
        compute P(z[i] = * | z[-i], Data) = alpha / (alpha + N - 1)
        normalize P(z[i] | ...)
        sample z[i] from P(z[i] | ...)
        add x[i]'s sufficient statistics to new cluster z[i]
            (possibly increase K)
```
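To make the pseudo-code concrete, here is a minimal Python sketch of one Gibbs sweep, reusing the hypothetical helpers `niw_prior`, `niw_posterior` and `multivariate_t_logpdf` from above (a sketch, not the post’s actual code):

```python
import numpy as np

def gibbs_sweep(X, z, alpha, prior):
    """One collapsed Gibbs sweep over all assignments.

    X: (N, d) data; z: (N,) integer cluster labels (numpy array);
    prior: output of niw_prior (requires kappa_0 > 0 for the new-cluster term).
    """
    mu_0, kappa_0, nu_0, psi_0 = prior
    N, d = X.shape
    for i in np.random.permutation(N):
        z[i] = -1                                  # remove x[i] from its old cluster
        labels = sorted(k for k in set(z.tolist()) if k != -1)  # empty clusters vanish
        log_w = []
        for k in labels:                           # existing clusters
            X_k = X[z == k]
            mu_n, kappa_n, nu_n, psi_n = niw_posterior(X_k, mu_0, kappa_0, nu_0, psi_0)
            df = nu_n - d + 1
            S = psi_n * (kappa_n + 1) / (kappa_n * df)
            log_w.append(np.log(len(X_k) / (alpha + N - 1))
                         + multivariate_t_logpdf(X[i], mu_n, S, df))
        df0 = nu_0 - d + 1                         # new cluster: prior predictive
        S0 = psi_0 * (kappa_0 + 1) / (kappa_0 * df0)
        log_w.append(np.log(alpha / (alpha + N - 1))
                     + multivariate_t_logpdf(X[i], mu_0, S0, df0))
        log_w = np.array(log_w)
        w = np.exp(log_w - log_w.max())            # normalize in log space for stability
        w /= w.sum()
        c = np.random.choice(len(w), p=w)
        z[i] = labels[c] if c < len(labels) else (max(labels) + 1 if labels else 0)
    return z
```

This sketch recomputes each cluster’s NIW posterior from scratch at every step; a real implementation would instead add and remove sufficient statistics incrementally, as the pseudo-code does.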


Results

Here is the result of our implementation of collapsed Gibbs sampling for the DPGMM, compared to scikit-learn’s implementation of the variational Bayes DPGMM:

Code

Here is my quick-and-dirty code implementing this version of Gibbs sampling for the DPGMM. You may want to comment out the scikit-learn parts (which I used for the comparison above) if you do not have it installed.

 
Posted by syhw, Mar 10th, 2013. Tags: Bayesian, Dirichlet process, Gibbs sampling, mixture models.
 