Research article

JAX-COSMO: An End-to-End Differentiable and GPU Accelerated Cosmology Library

Campagne, J. E. • Lanusse, F. • Zuntz, J. • et al.
2023
Open Journal of Astrophysics

We present jax-cosmo, a library for automatically differentiable cosmological theory calculations. jax-cosmo uses the JAX library, which has created a new coding ecosystem, especially in probabilistic programming. As well as batch acceleration, just-in-time compilation, and automatic optimization of code for different hardware modalities (CPU, GPU, TPU), JAX exposes an automatic differentiation (autodiff) mechanism. Thanks to autodiff, jax-cosmo gives access to the derivatives of cosmological likelihoods with respect to any of their parameters, and thus enables a range of powerful Bayesian inference algorithms, otherwise impractical in cosmology, such as Hamiltonian Monte Carlo and Variational Inference. In its initial release, jax-cosmo implements background evolution, linear and non-linear power spectra (using halofit or the Eisenstein and Hu transfer function), as well as angular power spectra (Cℓ) with the Limber approximation for galaxy and weak lensing probes, all differentiable with respect to the cosmological parameters and their other inputs. We illustrate how automatic differentiation can be a game-changer for common tasks involving Fisher matrix computations, or full posterior inference with gradient-based techniques (e.g. Hamiltonian Monte Carlo). In particular, we show how Fisher matrices are now fast, exact, no longer require any fine-tuning, and are themselves differentiable with respect to parameters of the likelihood, enabling complex survey optimization by simple gradient descent. Finally, using a Dark Energy Survey Year 1 3x2pt analysis as a benchmark, we demonstrate how jax-cosmo can be combined with Probabilistic Programming Languages such as NumPyro to perform posterior inference with state-of-the-art algorithms including a No U-Turn Sampler (NUTS), Automatic Differentiation Variational Inference (ADVI), and Neural Transport HMC (NeuTra). We show that the effective sample size per node (1 GPU or 32 CPUs) per hour of wall time is about 5 times better for a JAX NUTS sampler compared to the well-optimized Cobaya Metropolis-Hastings sampler. We further demonstrate that Normalizing Flows using Neural Transport are a promising methodology for model validation in the early stages of analysis.
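The Fisher-matrix point in the abstract lends itself to a short illustration. Below is a minimal sketch of the idea in JAX, using a toy Gaussian log-likelihood as a stand-in for a real cosmological one; log_like, theta0, and the mock data are hypothetical names for this sketch, not part of the jax-cosmo API. The Fisher matrix is obtained as the negative Hessian of the log-likelihood at the fiducial parameters, with exact derivatives and no finite-difference step-size tuning.

import jax
import jax.numpy as jnp

# Toy Gaussian log-likelihood, standing in for a cosmological one.
def log_like(theta, data, cov_inv):
    # Hypothetical "theory prediction", linear in the parameters.
    model = theta[0] + theta[1] * jnp.arange(data.shape[0])
    r = data - model
    return -0.5 * r @ cov_inv @ r

theta0 = jnp.array([1.0, 0.5])                   # fiducial parameters
data = theta0[0] + theta0[1] * jnp.arange(10.0)  # noiseless mock data vector
cov_inv = jnp.eye(10)                            # inverse data covariance

# Fisher matrix = negative Hessian of the log-likelihood at the
# fiducial point: exact, fast, and itself differentiable if needed.
fisher = -jax.hessian(log_like)(theta0, data, cov_inv)

Because the whole computation stays inside JAX, the same log_like could in principle be handed to gradient-based samplers (e.g. NUTS in NumPyro) without further work, which is the workflow the abstract benchmarks.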

Files
Name: 10.21105_astro.2302.05163.pdf
Type: Main Document
Version: Published version
Access type: Open access
License Condition: CC BY
Size: 0 B
Format: Adobe PDF
Checksum (MD5): d41d8cd98f00b204e9800998ecf8427e

Contact: infoscience@epfl.ch

Infoscience is a service managed and provided by the Library and IT Services of EPFL. © EPFL, all rights reserved.