I found myself thinking about using Turing.jl to develop a simple classifier for scRNA-Seq data, so I thought it would be helpful to write up a simple demo of how the package works.[1]
```julia
using Turing
using StatsPlots
using Random
Random.seed!(42);
```
Turing.jl is a package for probabilistic modeling and Bayesian inference in Julia. If you don’t have a formal background in ML or statistics (like me), this might be a bit intimidating. But if you already work with data science packages and want to get into modeling with MCMC, Turing seems like a great place to start.
A simple Gaussian model
Here we start by defining perhaps the simplest possible model, a single Gaussian distribution. The model function accepts data $x$ as an argument and defines prior distributions for the mean $\mu$ and standard deviation $\sigma$.
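Written out, the model defined next corresponds to the unnormalized posterior below: the two priors times the likelihood of $n$ observations, each assumed to be an independent draw from $\mathcal{N}(\mu, \sigma)$. $\mathrm{Exponential}(\sigma \mid 1)$ denotes the exponential density with rate 1. MCMC only needs this density up to a constant, which is why we never have to compute the normalizing integral.

$$
p(\mu, \sigma \mid x) \;\propto\; \mathcal{N}(\mu \mid 0, 5)\,\mathrm{Exponential}(\sigma \mid 1) \prod_{i=1}^{n} \mathcal{N}(x_i \mid \mu, \sigma)
$$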
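```julia
@model function simple_gaussian(x)
    μ ~ Normal(0, 5)      # prior: centered at 0 with a standard deviation of 5
    σ ~ Exponential(1)    # prior: exponential with a rate of 1
    # likelihood: each observation is an independent draw from Normal(μ, σ)
    for i in eachindex(x)
        x[i] ~ Normal(μ, σ)
    end
end
```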
Now we can use Turing to run Markov chain Monte Carlo (MCMC)[2] methods, which sample from the posterior distribution of the model parameters given some data. Here we generate 100 random draws from a standard normal distribution ($\mu$ = 0, $\sigma$ = 1) and then use a specific sampler, HMC(), or Hamiltonian Monte Carlo, to draw 1,000 samples from the posterior distribution of these parameters. The two arguments to HMC(0.05, 10) are the leapfrog step size and the number of leapfrog steps per iteration.
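To make those two HMC arguments concrete, here is a minimal pure-Julia sketch of plain HMC. It is not Turing's implementation. To keep it to one parameter with a hand-written gradient, it assumes σ is fixed at 1 so only μ is sampled, and the function names (`logpost`, `grad_logpost`, `toy_hmc`) are hypothetical. Each iteration draws a random momentum, runs `L` leapfrog steps of size `ϵ`, and accepts or rejects the endpoint with a Metropolis test on the total energy.

```julia
using Random, Statistics

# Hypothetical toy target: the posterior of μ for simple_gaussian with σ held
# fixed at 1 (an assumption so there is a single parameter and an easy gradient).
logpost(μ, x) = -μ^2 / (2 * 5^2) - sum(abs2, x .- μ) / 2   # log prior + log likelihood, up to a constant
grad_logpost(μ, x) = -μ / 5^2 + sum(x .- μ)                 # derivative of logpost with respect to μ

# Plain HMC: ϵ is the leapfrog step size and L the number of leapfrog steps,
# playing the same roles as the two arguments of HMC(0.05, 10).
function toy_hmc(x, n_samples; ϵ = 0.05, L = 10, μ0 = 0.0, rng = Random.default_rng())
    samples = Vector{Float64}(undef, n_samples)
    μ = μ0
    for s in 1:n_samples
        p = randn(rng)                                  # draw a fresh momentum
        μ_new, p_new = μ, p
        p_new += ϵ / 2 * grad_logpost(μ_new, x)         # half step for the momentum
        for l in 1:L
            μ_new += ϵ * p_new                          # full step for the position
            if l < L
                p_new += ϵ * grad_logpost(μ_new, x)     # full step for the momentum
            end
        end
        p_new += ϵ / 2 * grad_logpost(μ_new, x)         # final half step for the momentum
        # Metropolis test on the Hamiltonian H(μ, p) = -logpost(μ) + p^2 / 2
        log_accept = (logpost(μ_new, x) - p_new^2 / 2) - (logpost(μ, x) - p^2 / 2)
        if log(rand(rng)) < log_accept
            μ = μ_new
        end
        samples[s] = μ
    end
    return samples
end

rng = MersenneTwister(42)
data = randn(rng, 100)                  # 100 draws from a standard normal, as in the text
draws = toy_hmc(data, 1000; rng = rng)
println((mean(draws), std(draws)))
```

Because both the prior and the likelihood are Gaussian in this toy setting, the exact posterior of μ is Normal with precision $1/25 + n$ and mean $\sum_i x_i / (1/25 + n)$, which gives a handy check on the draws.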
```julia
chain = sample(simple_gaussian(randn(100)), HMC(0.05, 10), 1000, progress=true)
```

```
Chains MCMC chain (1000×14×1 Array{Union{Missing, Float64}, 3}):
Iterations        = 1:1:1000
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 0.16 seconds
Compute duration  = 0.16 seconds
parameters        = μ, σ
internals         = logprior, loglikelihood, logjoint, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, numerical_error, step_size, nom_step_size
Use `describe(chains)` for summary statistics and quantiles.
```
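Loading StatsPlots gives us a plot recipe for MCMC chains, so we can quickly visualize the results of our sampling: for each parameter, a trace plot of the draws and a density estimate of its posterior.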
```julia
p = plot(chain; size = (700, 300), fmt=:svg, margin=(3, :mm))
```

```julia
mean(chain[:μ]), mean(chain[:σ])
# (-0.11427439942358561, 0.9806775416893045)
```
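These estimates are close to the true values of 0 and 1, which is what we would expect given that we generated the data from a standard normal distribution. With 100 observations, the standard error of the mean is roughly $1/\sqrt{100} = 0.1$, so an estimate of about $-0.11$ for $\mu$ is within the expected sampling noise.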
Now, what if we updated the input data to reflect a different distribution?
```julia
d2 = Normal(5, 2) # a gaussian distribution with mean 5 and standard deviation 2
x = rand(d2, 50) # generate samples
h = histogram(
    x;
    bins=30,
    normalize=false,
    legend = false,
    title="Some data",
    size=(350, 350),
    fmt=:svg,
)
ylabel!("Frequency")
xlabel!("x")
```
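Now we can use the same model to estimate the parameters of this new distribution. This time we use a different sampler, NUTS() (the No-U-Turn Sampler), a variant of HMC that adapts its step size during warm-up and chooses the number of leapfrog steps on the fly, so we do not have to hand-tune the arguments we passed to HMC().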
```julia
chain = sample(simple_gaussian(x), NUTS(), 1000, progress=true)
mean(chain[:μ]), mean(chain[:σ])
# (4.987134902761898, 1.9834541544269715)
```
It estimates μ and σ pretty well, even from only 50 data points.
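As a rough sanity check, with fairly weak priors and 50 points the posterior means should land near the plain sample estimates. The sketch below uses only Distributions and Statistics, not Turing; the names `x_demo`, `m`, `s`, and `se` are hypothetical, and it uses its own seeded RNG, so its numbers will differ from the run above.

```julia
using Distributions, Random, Statistics

rng = MersenneTwister(42)
x_demo = rand(rng, Normal(5, 2), 50)   # same recipe as above: 50 draws from Normal(5, 2)

# Classical point estimates to compare with the posterior means of μ and σ
m = mean(x_demo)
s = std(x_demo)
se = s / sqrt(length(x_demo))          # standard error of the sample mean
println("mean = $m, std = $s, standard error = $se")
```

If the posterior mean of μ sits within a couple of standard errors of `m`, and that of σ is close to `s`, the sampler is behaving as expected.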
Gaussian Mixture Models (GMMs)
Now, let’s consider the case where our data is generated from a mixture of two Gaussian distributions. Turing makes it pretty easy to define and sample from such a model.
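```julia
@model function two_gaussian_mixture(x)
    # Priors for the parameters
    μ1 ~ Normal(0, 5)
    μ2 ~ Normal(0, 5)
    σ1 ~ Exponential(1)
    σ2 ~ Exponential(1)
    w ~ Dirichlet(2, 1.0)  # mixture weights (named w to avoid shadowing Julia's built-in π)
    # GMM likelihood: each point is an independent draw from the two-component mixture
    mix = MixtureModel([Normal(μ1, σ1), Normal(μ2, σ2)], w)
    for i in eachindex(x)
        x[i] ~ mix
    end
end
```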
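A MixtureModel never needs to know which component produced a given point: its density is the weighted sum of the component densities, $p(x_i) = w_1\,\mathcal{N}(x_i \mid \mu_1, \sigma_1) + w_2\,\mathcal{N}(x_i \mid \mu_2, \sigma_2)$. A minimal check with Distributions alone, using the true parameters of the data generated next; the names `comps`, `x_i`, and `by_hand` are hypothetical.

```julia
using Distributions

# True components and weights of the simulated data (300 and 50 points)
w = [300, 50] ./ 350
comps = [Normal(0, 1), Normal(5, 2)]
mix = MixtureModel(comps, w)

# Log density of a single point, from Distributions and by hand
x_i = 1.3
by_hand = log(w[1] * pdf(comps[1], x_i) + w[2] * pdf(comps[2], x_i))
println(logpdf(mix, x_i), " ≈ ", by_hand)
```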
Now we can generate some data from a mixture of two Gaussians and use our model to estimate the parameters.
```julia
d1 = Normal(0, 1)
d2 = Normal(5, 2)
x = vcat(rand(d1, 300), rand(d2, 50)) # generate samples from both distributions
h = histogram(
    x;
    bins=30,
    normalize=true,
    legend = true,
    label = "samples",
    title="Some more data",
    size=(400, 400),
    fmt=:svg,
)
ylabel!("Density")
xlabel!("x")
# overlay each true component density, scaled by its mixture weight (300/350 and 50/350)
x_range = -5:0.1:10
pdf_values_1 = (300 / 350) .* pdf.(d1, x_range)
pdf_values_2 = (50 / 350) .* pdf.(d2, x_range)
plot!(
    x_range,
    pdf_values_1;
    label="d1",
    color=:red,
    linestyle=:dash
)
plot!(
    x_range,
    pdf_values_2;
    label="d2",
    color=:orange,
    linestyle=:dash
)
```
This dataset is a mixture of two Gaussians, one centered at 0 and the other at 5, weighted 6:1 (300 points vs. 50). Now we can estimate the parameters of this more interesting mixture distribution. I will just plot the means here.[3]
```julia
chain = sample(two_gaussian_mixture(x), NUTS(), 1000, progress=true)
plot(chain[[:μ1, :μ2]]; size = (600, 400), fmt=:svg, margin=(8, :mm))
```
We can use the describe function to get a summary of the posterior distribution for each parameter. For example, to get a summary for the mean of the first Gaussian component:
```julia
describe(chain[:μ1])
```

```
Summary Stats:
Length:         1000
Missing Count:  0
Mean:           4.662154
Std. Deviation: 0.614489
Minimum:        2.359841
1st Quartile:   4.344250
Median:         4.745427
3rd Quartile:   5.105364
Maximum:        5.942638
Type:           Float64
```
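Notice that μ1 ended up describing the component near 5 rather than 0 (the label switching mentioned in note 3). One common remedy is to break the symmetry between components with an ordering constraint. Below is a minimal sketch, assuming a reparameterization $\mu_2 = \mu_1 + \delta$ with $\delta > 0$; the model name `ordered_mixture` and the Exponential(5) prior on the gap δ are my own illustrative choices, not part of the model above.

```julia
using Turing

# Hypothetical variant of two_gaussian_mixture with ordered means:
# μ2 = μ1 + δ with δ > 0, so component 1 always has the smaller mean.
@model function ordered_mixture(x)
    μ1 ~ Normal(0, 5)
    δ ~ Exponential(5)        # assumed prior on the gap between the two means
    μ2 = μ1 + δ               # deterministic, so the chain stores μ1 and δ (not μ2)
    σ1 ~ Exponential(1)
    σ2 ~ Exponential(1)
    w ~ Dirichlet(2, 1.0)
    mix = MixtureModel([Normal(μ1, σ1), Normal(μ2, σ2)], w)
    for i in eachindex(x)
        x[i] ~ mix
    end
end
```

You would sample it exactly as before, e.g. `sample(ordered_mixture(x), NUTS(), 1000)`, and recover draws of μ2 as `Array(chain[:μ1]) .+ Array(chain[:δ])`.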
And with that, you should have some minimal basics to start building your own probabilistic models in Turing and plotting fuzzy caterpillars. 🐛
1. Mostly for future me when I come back to this in ~6mo, but maybe for others as well. 😄
2. MCMC is a class of algorithms used to sample from probability distributions, particularly when the distribution is complex and cannot be sampled from directly. The NUTS algorithm is an efficient MCMC method that automatically tunes its parameters to explore the posterior distribution effectively. For more details on MCMC, look here: https://en.wikipedia.org/wiki/Markov_chain_Monte_Carlo
3. You might notice that the component labels are flipped relative to how we generated the data: μ1 ends up near 5 rather than 0. This is the “label switching” problem that can occur in mixture models. Because both components share the same priors, swapping them (means, standard deviations, and weights together) leaves the model unchanged, so the labels are arbitrary and may switch during sampling.