MCX


Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.

Overview

MCX is a probabilistic programming library focused on sampling methods. It transforms model definitions into logpdf or sampling functions that are JIT-compiled with JAX and support batching on CPU, GPU, or TPU. It aims to provide sequential inference and performant sampling methods for Bayesian deep learning.
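To make "JIT-compiled logpdf with batching" concrete, here is a minimal plain-JAX sketch. It does not use MCX's own API. It assumes a hypothetical toy model (mu ~ Normal(0, 1), y_i ~ Normal(mu, 1)) whose log-density is written by hand, which is the kind of function MCX is described as generating from a model definition. `jax.vmap` evaluates it over a batch of parameter values and `jax.jit` compiles the result for whichever backend JAX is running on.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Hypothetical toy model (not an MCX model definition):
#   mu  ~ Normal(0, 1)
#   y_i ~ Normal(mu, 1)
# The log-density is written by hand here.
def logpdf(mu, y):
    log_prior = norm.logpdf(mu, loc=0.0, scale=1.0)
    log_likelihood = jnp.sum(norm.logpdf(y, loc=mu, scale=1.0))
    return log_prior + log_likelihood

# Vectorize over a batch of mu values (the data y is shared across the batch),
# then JIT-compile the batched function for CPU, GPU, or TPU.
batched_logpdf = jax.jit(jax.vmap(logpdf, in_axes=(0, None)))

y = jnp.array([0.5, 1.0, 1.5])
mus = jnp.linspace(-1.0, 2.0, 4)
values = batched_logpdf(mus, y)  # one log-density per mu, shape (4,)
```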

Features

  • Laser focus on sampling methods
  • JIT-compiled with JAX for GPU and TPU support
  • Modular and re-usable model definitions
  • Batch sampling runtime with interactive mode for real-time monitoring
  • Support for neural network layers, stochastic support, and causal inference tools
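The batch sampling idea can be illustrated with a short plain-JAX sketch. This is not MCX's runtime or sampler API: it swaps in a generic random-walk Metropolis algorithm on a hypothetical toy posterior (mu ~ Normal(0, 1), three observations y_i ~ Normal(mu, 1)), runs one chain with `jax.lax.scan`, and then uses `jax.vmap` and `jax.jit` to run several chains in parallel. The step size and sample counts are arbitrary illustrative choices.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Hypothetical toy posterior with fixed data.
y = jnp.array([0.5, 1.0, 1.5])

def logpdf(mu):
    return norm.logpdf(mu) + jnp.sum(norm.logpdf(y, loc=mu))

def rmh_step(state, key, step_size=0.5):
    """One random-walk Metropolis step for a single chain."""
    mu, logp = state
    key_prop, key_accept = jax.random.split(key)
    proposal = mu + step_size * jax.random.normal(key_prop)
    logp_prop = logpdf(proposal)
    # Accept with probability min(1, exp(logp_prop - logp)).
    accept = jnp.log(jax.random.uniform(key_accept)) < logp_prop - logp
    mu = jnp.where(accept, proposal, mu)
    logp = jnp.where(accept, logp_prop, logp)
    return (mu, logp), mu

def run_chain(key, mu0, num_samples=2000):
    # One PRNG key per step; scan carries (position, log-density).
    keys = jax.random.split(key, num_samples)
    _, samples = jax.lax.scan(rmh_step, (mu0, logpdf(mu0)), keys)
    return samples

# Batch over chains: vmap maps over keys and initial positions, jit compiles it.
num_chains = 4
run_chains = jax.jit(jax.vmap(run_chain))
chain_keys = jax.random.split(jax.random.PRNGKey(0), num_chains)
samples = run_chains(chain_keys, jnp.zeros(num_chains))  # shape (4, 2000)
```

For this conjugate toy model the exact posterior is Normal(0.75, 0.5²), so the pooled sample mean should land near 0.75.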