import jax
import jax.numpy as jnp
import numpy as np
from jax import vmap, lax
import diffrax
import equinox as eqx
from . import constants as cnst
from . import ABCMBTools as tools
import os
file_dir = os.path.dirname(__file__)
jax.config.update("jax_enable_x64", True)
"""
Cosmological perturbation evolution module.
Integrates linear perturbation equations for scalar modes across
cosmic time using background cosmology and species interactions.
"""
[docs]
class PerturbationEvolver(eqx.Module):
"""
Linear scalar perturbation evolution solver.
Evolves perturbations for all fluid species using Einstein-Boltzmann
equations in synchronous gauge.
Attributes:
-----------
species_list : tuple
A list of all fluids in the cosmology
species_dict : dict
A dictionary containing the names of all fluids, in the same order as
they appear in species_list.
k_axis_perturbations : jnp.array
A list of wavenumbers k at which to compute perturbations
specs : dict
A dictionary containing run options
adjoint : diffrax.adjoint
Adjoint mode for diffrax solves. Default is ForwardMode.
Methods:
--------
full_evolution : Evolve perturbations for multiple k modes
evolution_one_k : Evolve perturbations for single k mode
get_tca_on_off : Determine tight coupling approximation times
initial_conditions_one_k : Compute initial perturbation conditions
get_derivatives : Compute perturbation time derivatives
make_output_table : Create interpolatable perturbation table
"""
species_list : tuple
species_dict : dict
k_axis_perturbations : jnp.array
specs : dict
adjoint : "diffrax.adjoint" = eqx.field(static=True)
def __init__(
self,
species_list,
species_dict,
k_axis_perturbations=jnp.geomspace(1.e-4, 0.4, 600),
specs = {},
adjoint = diffrax.ForwardMode,
):
self.species_list = species_list
self.species_dict = species_dict
self.k_axis_perturbations = k_axis_perturbations
self.specs = specs
self.adjoint = adjoint
[docs]
def full_evolution(self, args):
"""
Evolve perturbations for multiple wavenumber modes.
Integrates perturbation equations for a range of k modes,
then interpolates results onto common time grid.
Parameters:
-----------
k : jnp.array
1D axis of wavenumbers k. Perturbations are computed and stored at these values.
args : tuple
Background cosmology and cosmological parameters (BG, params)
Returns:
--------
PerturbationTable
Interpolatable table of perturbation evolution
Notes:
------
Uses logarithmic k spacing from 10^-4 to ~0.5 Mpc^-1 with 100 points.
Time integration runs from early times to z=1 (lna=-ln(2)).
"""
BG, params = args
lna = jnp.linspace(BG.lna_transfer_start, 0., 500)
# This scan function is only used if on CPU.
# For GPUs we vmap over the wavenumbers instead
def scan_fun(_, ki):
# evolution_one_k returns shape (Nlna, Ny)
y = self.evolution_one_k(ki, lna, args)
return None, y
if jax.default_backend() =='gpu':
res = vmap(self.evolution_one_k,in_axes=[0,None,None])(self.k_axis_perturbations, lna, args)
else:
_, res = lax.scan(scan_fun, None, self.k_axis_perturbations) # res has shape (Nk, Nlna, Ny)
res = res.transpose(2, 1, 0) # Transpose so the shape is (Ny, Nlna, Nk), easier for vmapping over in PT
PT = self.make_output_table(lna, res, args)
return PT
[docs]
def get_starting_time(self, k, args):
"""
Determine tight coupling approximation time range.
Finds start and end times for tight coupling between photons and baryons
by computing when Thomson scattering becomes ineffective relative to
Hubble and horizon crossing time scales.
Parameters:
-----------
args : tuple
Background cosmology and cosmological parameters (BG, params)
Returns:
--------
tuple
(lna_start, lna_end) for tight coupling period
Notes:
------
Uses thresholds: τc/τh < 0.0015 (start), τh/τk < 0.07 (start),
τc/τh > 0.015 (end), τc/τk > 0.01 (end).
"""
BG, params = args
# 1) Starting lna
lna_start_range = jnp.linspace(-20.0, -10.0, 10000)
# a) τc/τh → f1(lna) = BG.tau_c * BG.aH
f1 = BG.tau_c(lna_start_range, params) * BG.aH(lna_start_range, params)
# invert f1(lna) = thr1 → lna = interp(thr1, f1, lna_range)
lna1 = jnp.interp(self.specs["R_tc"], f1, lna_start_range) # jnp.interp ends up being
# faster than fast_interp through here
# b) τh/τk → f2(lna) = k / BG.aH
f2 = k / BG.aH(lna_start_range, params)
# invert f2(lna) = thr2
lna2 = jnp.interp(self.specs["R_large"], f2, lna_start_range)
lna_ini = jnp.minimum(lna1, lna2)
return lna_ini
[docs]
def initial_conditions_one_k(self, k, lna_ini, args):
"""
Compute initial conditions for perturbation evolution.
Sets up initial values for metric and fluid perturbations at early times
using adiabatic initial conditions.
Parameters:
-----------
k : float
Wavenumber (units: Mpc^{-1})
lna_ini : float
Initial logarithm of scale factor
args : tuple
Background cosmology and cosmological parameters (BG, params)
Returns:
--------
array
Initial perturbation state vector
Notes:
------
Uses CLASS-style initial conditions with metric perturbations h and η.
Assumes adiabatic initial conditions with vanishing isocurvature modes.
"""
BG, params = args
### CLASS Initial Conditions ###
a = jnp.exp(lna_ini)
tau_ini = BG.tau(lna_ini)
om = params["om"]
metric_eta_ini = (1.-k**2*tau_ini**2/12./(15.+4.*params['R_nu'])*(5.+4.*params['R_nu'] - (16.*params['R_nu']*params['R_nu']+280.*params['R_nu']+325)/10./(2.*params['R_nu']+15.)*tau_ini*om))
all_fluid_ini = jnp.concatenate([p.y_ini(k, tau_ini, params) for p in self.species_list])
y_ini = jnp.concatenate((jnp.array([metric_eta_ini]), all_fluid_ini))
return y_ini
[docs]
def get_derivatives(self, lna, y, args):
"""
Compute time derivatives for perturbation evolution.
Assembles the full system of Einstein-Boltzmann equations for
metric and fluid perturbations in synchronous gauge.
Parameters:
-----------
lna : float
Logarithm of scale factor
y : array
Current perturbation state vector
args : tuple
Wavenumber k and background cosmology (k, BG, params)
Returns:
--------
array
Time derivatives of perturbation state
"""
k, BG, params = args
a = jnp.exp(lna)
aH = BG.aH(lna, params)
metric_eta = y[0]
# Metric perturbation derivatives
sum_rho_delta = 0.
sum_rho_plus_P_theta = 0.
for i in range(len(self.species_list)):
species = self.species_list[i]
# If species has density perturbation, add to total.
sum_rho_delta += species.rho_delta(lna, y, params)
# If species has velocity perturbation, add to total.
sum_rho_plus_P_theta += species.rho_plus_P_theta(lna, y, params)
metric_h_prime = 2./aH**2 * (k**2*metric_eta + 4.*jnp.pi*cnst.G*a**2/cnst.c_Mpc_over_s**2 * sum_rho_delta)
metric_eta_prime = 4.*jnp.pi*cnst.G*a**2/aH/k**2 * sum_rho_plus_P_theta / cnst.c_Mpc_over_s**2
# Now loop over all species and assemble their respective y_primes
args = (BG, params, self.species_list, self.species_dict)
y_prime = jnp.array([metric_eta_prime])
for i in range(len(self.species_list)):
species = self.species_list[i]
y_prime = jnp.concatenate((y_prime, species.y_prime(k, lna, metric_h_prime, metric_eta_prime, y, args)))
return y_prime
[docs]
def evolution_one_k(self, k, lna, args):
"""
Evolve perturbations for single wavenumber mode.
Integrates Einstein-Boltzmann equations from early times through
recombination to late times using adaptive time stepping.
Parameters:
-----------
k : float
Wavenumber (units: Mpc^{-1})
lna : array
Logarithm of scale factor grid for output
args : tuple
Background cosmology and cosmological parameters (BG, params)
Returns:
--------
diffrax.Solution
Dense solution object for interpolation
"""
### DIFFRAX INTEGRATION ###
lna_start = self.get_starting_time(k, args) # Start and end times from tight coupling settings
lna_end = 0.0
# For small k's the superhorizon time can be set relatively late, but we impose a cutoff of z~20000 for all modes
# at the very least.
lna_start = jnp.minimum(lna_start, -10.)
# Initial conditions for tight coupling
y_ini = self.initial_conditions_one_k(k, lna_start, args)
# Settings for post-tight coupling
term = diffrax.ODETerm(self.get_derivatives)
solver = diffrax.Kvaerno5()
rtol=jnp.where(
k > self.specs["k_split_PE"],
self.specs["rtol_large_k_PE"],
self.specs["rtol_small_k_PE"]
)
atol=jnp.where(
k > self.specs["k_split_PE"],
self.specs["atol_large_k_PE"],
self.specs["atol_small_k_PE"]
)
stepsize_controller = diffrax.PIDController(pcoeff=self.specs["pcoeff_PE"], icoeff=self.specs["icoeff_PE"], dcoeff=self.specs["dcoeff_PE"], rtol=rtol, atol=atol)
saveat = diffrax.SaveAt(ts=lna)
adjoint=self.adjoint()
sol = diffrax.diffeqsolve(
term, solver,
t0=lna_start, t1=lna_end, dt0=1.e-2, y0=y_ini,
stepsize_controller=stepsize_controller,
max_steps=self.specs["max_steps_PE"],
saveat=saveat,
args=(k,*args),
adjoint=adjoint
)
### END OF DIFFRAX INTEGRATION ###
return sol.ys
[docs]
def make_output_table(self, lna, modes, args):
"""
Create interpolatable perturbation table from evolution results.
Extracts key perturbation modes and computes derived quantities.
Parameters:
-----------
lna : array
Logarithm of scale factor grid
modes : array
Perturbation evolution results
args : tuple
Background cosmology and cosmological parameters (BG, params)
Returns:
--------
PerturbationTable
Organized perturbation data for interpolation
"""
k = self.k_axis_perturbations
BG, params = args
metric_eta = modes[0]
# Build per-species perturbation dicts first; theta_b_prime draws from them.
species_perturbations = {
s.name: s.output_perturbations(lna, modes, (BG, params))
for s in self.species_list
}
# Baryon velocity derivative — backward-calculated from the Boltzmann equations.
# Requires Baryon and Photon objects for cs2 and the photon-baryon coupling R.
Baryon = self.species_list[self.species_dict["Baryon"]]
Photon = self.species_list[self.species_dict["Photon"]]
delta_b = species_perturbations["Baryon"]["delta"]
theta_b = species_perturbations["Baryon"]["theta"]
theta_g = species_perturbations["Photon"]["theta"]
karr = k[None, :]
a = jnp.exp(lna)[:, None]
aH = BG.aH(lna, params)[:, None]
cs2 = Baryon.cs2(lna, (BG, params, self.species_list, self.species_dict))[:, None]
R = 4.*Photon.rho(lna, params)[:, None]/3./Baryon.rho(lna, params)[:, None]
tau_c = BG.tau_c(lna, params)[:, None]
theta_b_prime = -theta_b + cs2/aH*(karr**2*delta_b) + R/aH/tau_c*(theta_g-theta_b)
# Sum density/velocity/shear over all species for metric derivatives and delta_m.
sum_rho_delta = jnp.zeros_like(modes[0])
sum_rho_plus_P_theta = jnp.zeros_like(modes[0])
sum_rho_plus_P_sigma = jnp.zeros_like(modes[0])
sum_rho_delta_m = jnp.zeros_like(modes[0])
sum_rho_m = 0.
for s in self.species_list:
if s.num_equations > 0:
rho_delta = vmap(s.rho_delta, in_axes=(0, 1, None))(lna, modes, params)
sum_rho_delta += rho_delta
sum_rho_plus_P_theta += vmap(s.rho_plus_P_theta, in_axes=(0, 1, None))(lna, modes, params)
sum_rho_plus_P_sigma += vmap(s.rho_plus_P_sigma, in_axes=(0, 1, None))(lna, modes, params)
if s.is_matter:
sum_rho_delta_m += rho_delta
sum_rho_m += s.rho(lna, params)
delta_m = sum_rho_delta_m / sum_rho_m[:, None]
metric_h_prime = 2./aH**2 * (karr**2*metric_eta + 4.*jnp.pi*cnst.G*a**2/cnst.c_Mpc_over_s**2 * sum_rho_delta)
metric_eta_prime = 4.*jnp.pi*cnst.G*a**2/aH * sum_rho_plus_P_theta / cnst.c_Mpc_over_s**2 / karr**2
metric_alpha = aH*(metric_h_prime + 6.*metric_eta_prime)/2./karr**2
metric_alpha_prime = metric_eta/aH - 2.*metric_alpha \
- 12.*jnp.pi*cnst.G*a**2/aH * sum_rho_plus_P_sigma / cnst.c_Mpc_over_s**2 / karr**2
return PerturbationTable(
k,
lna,
delta_m,
theta_b_prime,
metric_eta,
metric_h_prime,
metric_eta_prime,
metric_alpha,
metric_alpha_prime,
species_perturbations,
)
[docs]
class PerturbationTable(eqx.Module):
"""
Interpolatable table of perturbation evolution.
Stores perturbation modes as 2D arrays over wavenumber and time
for efficient interpolation. Per-species perturbations in physically
meaningful form are accessible via species_perturbations.
Attributes:
-----------
k : array
Wavenumber grid (units: Mpc^{-1})
lna : array
Logarithm of scale factor grid
delta_m : array
Total matter density perturbation, weighted sum over all matter species
theta_b_prime : array
Baryon velocity derivative (backward-calculated from Boltzmann equations)
metric_eta : array
Metric perturbation η
metric_h_prime : array
Time derivative of metric h
metric_eta_prime : array
Time derivative of metric η
metric_alpha : array
Derived metric perturbation α
metric_alpha_prime : array
Time derivative of metric α
species_perturbations : dict
Named perturbation arrays for each species, keyed by species name.
Each value is a dict {quantity: array(Nlna, Nk)}.
Species with no perturbations (e.g. dark energy) map to {}.
"""
k : jnp.array
lna : jnp.array
delta_m : jnp.array
theta_b_prime : jnp.array
metric_eta : jnp.array
metric_h_prime : jnp.array
metric_eta_prime : jnp.array
metric_alpha : jnp.array
metric_alpha_prime : jnp.array
species_perturbations : dict