import numpy as np
import jax.numpy as jnp
import equinox as eqx
import jax
from jax import vmap, jit, config, grad, lax
from diffrax import diffeqsolve, ODETerm, Dopri5, Kvaerno3, Kvaerno5, Tsit5, SaveAt, PIDController, DiscreteTerminatingEvent
from jax.scipy.interpolate import RegularGridInterpolator
from functools import partial
from interpax import CubicSpline
from scipy.special import spherical_jn
from . import ABCMBTools as tools
from . import constants as cnst
import os
file_dir = os.path.dirname(__file__)
config.update("jax_enable_x64", True)
# 2D arrays of tabulated spherical functions over l and x axes.
bessel_l_tab = jnp.array(np.loadtxt(file_dir+"/bessel_tab/l.txt"), dtype="int")
xphi0_tab = jnp.array(np.loadtxt(file_dir+"/bessel_tab/xphi0.txt"))
phi0_tab = jnp.array(np.loadtxt(file_dir+"/bessel_tab/phi0.txt"))
xphi1_tab = jnp.array(np.loadtxt(file_dir+"/bessel_tab/xphi1.txt"))
phi1_tab = jnp.array(np.loadtxt(file_dir+"/bessel_tab/phi1.txt"))
xphi2_tab = jnp.array(np.loadtxt(file_dir+"/bessel_tab/xphi2.txt"))
phi2_tab = jnp.array(np.loadtxt(file_dir+"/bessel_tab/phi2.txt"))
try:
gpus = jax.devices('gpu')
bessel_l_tab = jax.device_put(
bessel_l_tab, device=gpus[0])
xphi0_tab = jax.device_put(
xphi0_tab, device=gpus[0])
phi0_tab = jax.device_put(
phi0_tab, device=gpus[0])
xphi1_tab = jax.device_put(
xphi1_tab, device=gpus[0])
phi1_tab = jax.device_put(
phi1_tab, device=gpus[0])
xphi2_tab = jax.device_put(
xphi2_tab, device=gpus[0])
phi2_tab = jax.device_put(
phi2_tab, device=gpus[0])
except:
pass
# large-x asymptotic expansion of spherical bessel functions
Q = lambda l, x : jnp.sqrt(x**2-l**2) - l*jnp.pi/2 + l * jnp.arcsin(l/x)
J = lambda l, x : jnp.sqrt(2/jnp.pi/jnp.sqrt(x**2-l**2)) * jnp.cos(Q(l, x) - jnp.pi/4)
j = lambda l, x : jnp.sqrt(jnp.pi/2/x) * J(l+1/2, x)
[docs]
def phi0(i, x):
"""
New method for computing phi0 (or just jl)
We tabulated the bessel function between its smallest value (~1.e-10) out to the fifth local maximum.
This is a different interval for each l, but we kept identical shape so it can be a large 2D array.
If the incoming argument is within this interval, we use fast_interp. Otherwise we use the large x expansion above.
"""
l = bessel_l_tab[i]
# use x_safe to avoid double-where gotcha in reverse AD
x_safe = jnp.where(x >= xphi0_tab[-1, i], x, xphi0_tab[-1, i])
return jnp.where(
x < xphi0_tab[0, i],
0.,
jnp.where(
x >= xphi0_tab[-1, i],
j(l, x_safe),
tools.fast_interp(x, xphi0_tab[:, i].min(), xphi0_tab[:, i].max(), phi0_tab[:, i])
)
)
[docs]
def phi1(i, x):
"""
New method for computing phi1, or jl'.
We tabulated the bessel function between its smallest value (~1.e-10) out to the fifth local maximum.
This is a different interval for each l, but we kept identical shape so it can be a large 2D array.
If the incoming argument is within this interval, we use fast_interp. Otherwise we use the large x expansion above.
"""
l = bessel_l_tab[i]
x_safe = jnp.where(x >= xphi1_tab[-1, i], x, xphi1_tab[-1, i])
return jnp.where(
x < xphi1_tab[0, i],
0.,
jnp.where(
x >= xphi1_tab[-1, i],
l/x_safe*j(l, x_safe) - j(l+1, x_safe),
tools.fast_interp(x, xphi1_tab[:, i].min(), xphi1_tab[:, i].max(), phi1_tab[:, i])
)
)
[docs]
def phi2(i, x):
"""
New method for computing phi2 = (3 jl'' + jl)/2
We tabulated the bessel function between its smallest value (~1.e-10) out to the fifth local maximum.
This is a different interval for each l, but we kept identical shape so it can be a large 2D array.
If the incoming argument is within this interval, we use fast_interp. Otherwise we use the large x expansion above.
"""
l = bessel_l_tab[i]
x_safe = jnp.where(x >= xphi2_tab[-1, i], x, xphi2_tab[-1, i])
return jnp.where(
x < xphi2_tab[0, i],
0.,
jnp.where(
x >= xphi2_tab[-1, i],
((3*l*(l-1)-2*x_safe**2)*j(l, x_safe)+6*x_safe*j(l+1, x_safe))/2/x_safe**2,
tools.fast_interp(x, xphi2_tab[:, i].min(), xphi2_tab[:, i].max(), phi2_tab[:, i])
)
)
[docs]
class SpectrumSolver(eqx.Module):
"""
CMB angular power spectrum computation.
Computes temperature and polarization angular power spectra by
integrating transfer functions over wavenumber and time.
Attributes:
-----------
ells : jnp.array
Multipole values for output power spectra
ells_indices : jnp.array
Indices into bessel_l_tab corresponding to ells
lensing_ells : jnp.array
Extended multipole range for lensing calculations
lensing_ells_indices : jnp.array
Indices into bessel_l_tab for lensing multipoles
lensing_mus : jnp.array
Used for lensing, the Gauss-Legendre quadrature roots for the correlation function -> Cl integral.
lensing_ws : jnp.array
Used for lensing, the Gauss-Legendre quadrature weights for the correlation function -> Cl integral.
lensing : bool
Whether to include gravitational lensing effects
k_axis_transfer : jnp.array
Wavenumber grid for transfer function integration (units: Mpc^{-1})
k_axis_Pk_output : jnp.array
Wavenumber grid for matter power spectrum output (units: Mpc^{-1})
k_pivot : float
Pivot scale for primordial power spectrum normalization (units: Mpc^{-1}, default: 0.05)
scale_sw : float
Multiplicative factor for Sachs-Wolfe term (default: 1.0)
scale_isw : float
Multiplicative factor for integrated Sachs-Wolfe term (default: 1.0)
scale_dop : float
Multiplicative factor for Doppler term (default: 1.0)
scale_pol : float
Multiplicative factor for polarization term (default: 1.0)
Methods:
--------
primordial_spectrum : Compute primordial power spectrum
Pk_lin : Compute linear matter power spectrum
get_Cl : Compute angular power spectra for multiple ℓ
Cl_one_ell : Compute angular power spectrum for single ℓ
integrand_T0 : Compute SW+ISW temperature source integrand
integrand_T1 : Compute ISW temperature source integrand
integrand_T2 : Compute polarization temperature source integrand
integrand_E : Compute E-mode polarization source integrand
"""
ells : jnp.array
ells_indices : jnp.array
lensing_ells : jnp.array
lensing_ells_indices : jnp.array
lensing_mus : jnp.array
lensing_ws : jnp.array
lensing : bool
k_axis_transfer : jnp.array
k_axis_Pk_output : jnp.array
k_pivot : float = 0.05 # In 1/Mpc
scale_sw : float = 1.
scale_isw : float = 1.
scale_dop : float = 1.
scale_pol : float = 1.
def __init__(self,
ellmin=2,
ellmax=2500,
lensing=False,
k_axis_transfer=jnp.geomspace(1.e-4, 0.4, 2500),
k_axis_Pk_output=jnp.geomspace(1.e-4, 0.1, 100),
k_pivot=0.05,
scale_sw=1,
scale_isw=1,
scale_dop=1,
scale_pol=1):
"""
Initialize CMB spectrum solver.
Sets up multipole range, lensing configuration, and source term switches
for computing angular power spectra.
Parameters:
-----------
ellmin : int, optional
Minimum multipole (default: 2)
ellmax : int, optional
Maximum multipole (default: 2500)
lensing : bool, optional
Whether to include lensing effects (default: True)
k_pivot : float, optional
Pivot scale for primordial spectrum (units: Mpc^{-1}, default: 0.05)
scale_sw : float, optional
Switch for Sachs-Wolfe term (default: 1)
scale_isw : float, optional
Switch for integrated Sachs-Wolfe term (default: 1)
scale_dop : float, optional
Switch for Doppler term (default: 1)
scale_pol : float, optional
Switch for polarization term (default: 1)
"""
self.lensing = lensing
self.ells = jnp.arange(ellmin, ellmax+1)
ell_idx_min = jnp.where(bessel_l_tab<=ellmin)[0][-1]
ell_idx_max = jnp.where(bessel_l_tab>=ellmax)[0][0]
self.ells_indices = jnp.arange(ell_idx_min, ell_idx_max+1)
if self.lensing:
lensing_ellmax = ellmax+500
lensing_ell_idx_max = jnp.where(bessel_l_tab>=lensing_ellmax)[0][0]
self.lensing_ells = jnp.arange(ellmin, lensing_ellmax+1)
self.lensing_ells_indices = jnp.arange(ell_idx_min, lensing_ell_idx_max+1)
#self.lensing_theta = jnp.linspace(0., jnp.pi/16., lensing_ellmax // 8) # Size recommended by CLASS
num_mu = lensing_ellmax + 70
mu, w = tools.gauss_legendre_weights(num_mu)
self.lensing_mus = jnp.concatenate((mu, jnp.array([1.])))
self.lensing_ws = jnp.concatenate((w, jnp.array([0.])))
else:
self.lensing_ells = self.ells
self.lensing_ells_indices = self.ells_indices
#self.lensing_theta = jnp.array([0.]) # Not needed
self.lensing_mus = jnp.array([0.]) # Not needed
self.lensing_ws = jnp.array([0.]) # Not needed
self.k_axis_transfer = k_axis_transfer
self.k_axis_Pk_output = k_axis_Pk_output
self.k_pivot = k_pivot
self.scale_sw = scale_sw
self.scale_isw = scale_isw
self.scale_dop = scale_dop
self.scale_pol = scale_pol
[docs]
def primordial_spectrum(self, k, params):
"""
Compute primordial curvature power spectrum.
Parameters:
-----------
k : float or array
Wavenumber (units: Mpc^{-1})
params : dict
Dictionary of input and derived parameters
Returns:
--------
float or array
Primordial power spectrum P_R(k), units Mpc^3
"""
return params['A_s']*(k/self.k_pivot)**(params['n_s']-1.) * (2*jnp.pi**2/k**3)
[docs]
def Pk_lin(self, k, z, PT, params):
"""
Compute linear matter power spectrum at wavenumbers k and redshift z.
Parameters:
-----------
k : float or array
Wavenumber (Mpc^{-1})
z : float
Redshift to evaluate.
PT : perturbations.PerturbationTable
Perturbation evolution table
params : dict
Dictionary of input and derived parameters
Returns:
--------
float or array
Linear matter power spectrum P(k, z), units Mpc^3
"""
lna = -jnp.log(1.+z)
# vmapped interpolation over Nk (columns of the 2D arrays)
interp_over_lna = jax.vmap(
lambda y: jnp.interp(lna, PT.lna, y),
in_axes=1 # loop over columns
)
delta_m_lna = interp_over_lna(PT.delta_m) # shape (Nk,)
# now interpolate over k
delta_m = jnp.interp(k, PT.k, delta_m_lna)
return delta_m**2 * self.primordial_spectrum(k, params)
[docs]
def Pk_cb(self, k, z, PT, params):
"""
Compute linear Baryon+DarkMatter power spectrum at wavenumbers k and redshift z.
Does not include any other massive species present.
Parameters:
-----------
k : float or array
Wavenumber (Mpc^{-1})
z : float
Redshift to evaluate.
PT : perturbations.PerturbationTable
Perturbation evolution table
params : dict
Dictionary of input and derived parameters
Returns:
--------
float or array
Linear Baryon+DarkMatter power spectrum P_cb(k, z), units Mpc^3
"""
lna = -jnp.log(1.+z)
# vmapped interpolation over Nk (columns of the 2D arrays)
interp_over_lna = jax.vmap(
lambda y: jnp.interp(lna, PT.lna, y),
in_axes=1 # loop over columns
)
delta_dm_lna = interp_over_lna(PT.species_perturbations["ColdDarkMatter"]["delta"])
delta_b_lna = interp_over_lna(PT.species_perturbations["Baryon"]["delta"])
# now interpolate over k
delta_dm = jnp.interp(k, PT.k, delta_dm_lna)
delta_b = jnp.interp(k, PT.k, delta_b_lna)
# total matter overdensity
delta_m = (
params['omega_b'] * delta_b +
params['omega_cdm'] * delta_dm
) / params['omega_m']
return delta_m**2 * self.primordial_spectrum(k, params)
[docs]
def lensing_power_spectrum(self, k, lna, PT, BG, params):
"""
Computes the lensing power spectrum at wavenumbers k and redshift z.
Eq.(3.15) in astro-ph/0601594
Parameters:
-----------
k : float or array
Wavenumber (Mpc^{-1})
lna : float
Scale factor
PT : perturbations.PerturbationTable
Perturbation evolution table
BG : background.Background
Background cosmology module
params : dict
Dictionary of input and derived parameters
Returns:
--------
float or array
Lensing matter power spectrum P(k, z), dimensionless.
"""
a = jnp.exp(lna)
z = 1./a - 1.
aH = BG.aH(lna, params)
Omega_m = params["omega_m"]/params["h"]**2
Omega_L = params["omega_Lambda"]/params["h"]**2
# Matter fraction over time after equality. 1 at early times and becomes Om0 today.
Om = (Omega_m * (1.+z)**3)/ ((Omega_m * (1.+z)**3) + Omega_L)
Pk = self.Pk_lin(k, z, PT, params) # Mpc^3
return 9./8./jnp.pi**2 * Om**2 * aH**4 * Pk / k
[docs]
def lensing_Cl(self, ells, PT, BG, params):
"""
Angular lensing power spectrum at multipole ell.
IMPORTANT: Assumes Limber approximation throughout, even at ell=2.
Eq.(3.14) in astro-ph/0601594, except shifts ell -> ell+1/2 to match CLASS.
Parameters:
-----------
ell : float or array
Multipole
PT : perturbations.PerturbationTable
Perturbation evolution table
BG : background.Background
Background cosmology module
params : dict
Dictionary of input and derived parameters
Returns:
--------
float or array
Angular lensing matter power spectrum Cl^phiphi, dimensionless.
"""
coeff = 8.*jnp.pi**2/(ells+0.5)**3
chi = lambda lna : BG.tau0 - BG.tau(lna)
# The previous jnp.nan_to_num(integrand, nan=0.) here masked the
# forward NaN but left a 0*NaN cotangent in the backward through
# the where-mask that nan_to_num secretly expands to, which
# propagated through BG.tau. Fix: substitute lna_safe everywhere,
# then mask the result to 0 at the boundary.
lna_axis = jnp.linspace(BG.lna_rec, 0., 500)
lna_floor = lna_axis[-2]
def integrand_func(lna):
lna_safe = jnp.where(lna < 0., lna, lna_floor)
chi_safe = chi(lna_safe)
k = (ells+0.5)/chi_safe
window = (chi(BG.lna_rec) - chi_safe)/chi(BG.lna_rec)/chi_safe
res = (
chi_safe / BG.aH(lna_safe, params)
* window**2
* self.lensing_power_spectrum(k, lna_safe, PT, BG, params)
)
return jnp.where(lna < 0., res, 0.)
integrand = vmap(integrand_func)(lna_axis)
return coeff*jnp.trapezoid(integrand, lna_axis, axis=0)
[docs]
def lensed_Cls(self, ells, ClTT_unlensed, ClTE_unlensed, ClEE_unlensed, PT, BG, params):
"""
Compute lensed CMB power spectra.
Applies gravitational lensing corrections to unlensed temperature
and polarization power spectra using Wigner rotation matrices.
Parameters:
-----------
ells : array
Multipole values
ClTT_unlensed : array
Unlensed temperature power spectrum
ClTE_unlensed : array
Unlensed temperature-E-mode cross spectrum
ClEE_unlensed : array
Unlensed E-mode polarization power spectrum
PT : perturbations.PerturbationTable
Perturbation evolution table
BG : background.Background
Background cosmology module
params : dict
Dictionary of input and derived parameters
Returns:
--------
tuple
(ClTT, ClTE, ClEE) lensed power spectra
"""
# CLASS samples angle uniformly
# 500 points is enough for lmax < 4000
#theta = jnp.linspace(0., jnp.pi/16., 500)
# Flip mu so that mu is in ascending order, works better for trapz.
#mu = jnp.flip(jnp.cos(self.lensing_theta))
mu = self.lensing_mus
# Compute lensing Cl
Clpp = self.lensing_Cl(ells, PT, BG, params)
# Wigner matrices needed in general and for temperature
# Note that for all wigner matrices, the symmetry relation is dnm = (-1)^(m-n) x dmn
d00 = tools.d00(mu, ells)
d11 = tools.d1n(mu, ells, 1)
d1m1 = tools.d1n(mu, ells, -1)
d2m2 = tools.d2n(mu, ells, -2)
dm11 = d1m1
# Wigner matrices needed for polarization
d22 = tools.d2n(mu, ells, 2)
d31 = tools.d3n(mu, ells, 1)
d40 = tools.d4n(mu, ells, 0)
d3m3 = tools.d3n(mu, ells, -3)
d4m4 = tools.d4n(mu, ells, -4)
d20 = tools.d2n(mu, ells, 0)
d3m1 = tools.d3n(mu, ells, -1)
d4m2 = tools.d4n(mu, ells, -2)
d02 = d20
dm24 = d4m2
# Lensing angular correlation function
Cgl = 1./4./jnp.pi * jnp.sum(
(2.*ells+1)*ells*(ells+1)*Clpp*d11, axis=1
) # Nmu
Cgl2 = 1./4./jnp.pi * jnp.sum(
(2.*ells+1)*ells*(ells+1)*Clpp*dm11, axis=1
) # Nmu
sigma2 = Cgl[-1] - Cgl
Cgl = Cgl[:, None]
Cgl2 = Cgl2[:, None]
sigma2 = sigma2[:, None]
llp1 = ells*(ells+1)
X000 = jnp.exp(-llp1*sigma2/4)
X000_prime = -llp1/4.*X000
X220 = 1./4.*jnp.sqrt((ells+2)*(ells-1)*ells*(ells+1))*jnp.exp(-(llp1-2)*sigma2/4.)
X022 = jnp.exp(-(llp1-4)*sigma2/4)
X022_prime = -(llp1-4)/4*X022
X121 = -1./2.*jnp.sqrt((ells+2)*(ells-1))*jnp.exp(-(llp1-8./3.)*sigma2/4.)
X132 = -1./2.*jnp.sqrt((ells+3)*(ells-2))*jnp.exp(-(llp1-20./3.)*sigma2/4.)
X242 = 1./4.*jnp.sqrt((ells+4)*(ells+3)*(ells-2)*(ells-3))*jnp.exp(-(llp1-10.)*sigma2/4.)
# Correlation functions
ksi = 1./4./jnp.pi * jnp.sum(
(2.*ells+1)*ClTT_unlensed * (
X000**2 * d00 \
+ 8./ells/(ells+1)*Cgl2*X000_prime**2*d1m1 \
+ Cgl2**2 * (X000_prime**2*d00 + X220**2*d2m2) \
#- d00
),
axis=1
)
ksip = 1./4./jnp.pi * jnp.sum(
(2.*ells+1)*ClEE_unlensed * (
X022**2 * d22 \
+ 2*Cgl2*X132*X121*d31 \
+ Cgl2**2 * (X022_prime**2*d22 + X242*X220*d40) \
#- d22
),
axis=1
)
ksim = 1./4./jnp.pi * jnp.sum(
(2.*ells+1)*ClEE_unlensed * (
X022**2 * d2m2 \
+ Cgl2*(X121**2*d1m1 + X132**2*d3m3) \
+ 1./2.*Cgl2**2 * (2*X022_prime**2*d2m2 + X220**2*d00 + X242**2*d4m4) \
#- d2m2
),
axis=1
)
ksix = 1./4./jnp.pi * jnp.sum(
(2.*ells+1)*ClTE_unlensed * (
X022*X000*d02 \
+ Cgl2 * 2*X000_prime/jnp.sqrt(llp1) * (X121*d11 + X132*d3m1) \
+ 1./2.*Cgl2**2 * ((2*X022_prime*X000_prime+X220**2)*d20+X220*X242*dm24) \
#- d02
),
axis=1
)
#ClTT = 2.*jnp.pi * jnp.trapezoid(ksi[:, None]*d00, mu, axis=0) + ClTT_unlensed
#ClTE = 2.*jnp.pi * jnp.trapezoid(ksix[:, None]*d20, mu, axis=0) + ClTE_unlensed
#ClEE = 1./2. * 2.*jnp.pi * jnp.trapezoid(ksip[:, None]*d22+ksim[:, None]*d2m2, mu, axis=0) + ClEE_unlensed
w = self.lensing_ws[:, None]
ClTT = 2*jnp.pi * jnp.sum(ksi[:, None]*d00*w, axis=0)
ClTE = 2*jnp.pi * jnp.sum(ksix[:, None]*d20*w, axis=0)
ClEE = 1./2. * 2*jnp.pi * jnp.sum(
(ksip[:, None]*d22 + ksim[:, None]*d2m2)*w,
axis=0
)
return (ClTT, ClTE, ClEE)
[docs]
def get_Cl(self, PT, BG, params):
"""
Compute angular power spectra for multiple multipoles.
Parameters:
-----------
PT : perturbations.PerturbationTable
Perturbation evolution table
BG : background.Background
Background cosmology module
params : dict
Dictionary of input and derived parameters
Returns:
--------
tuple
(ClTT, ClTE, ClEE) angular power spectra
"""
tt_raw, te_raw, ee_raw = vmap(self.Cl_one_ell, in_axes=(0, None, None, None))(self.lensing_ells_indices, PT, BG, params)
# Cubic spline for smooth Cl over user requested ells
lensing_ells = bessel_l_tab[self.lensing_ells_indices]
tt_unlensed = CubicSpline(lensing_ells, tt_raw, check=False)(self.lensing_ells)
te_unlensed = CubicSpline(lensing_ells, te_raw, check=False)(self.lensing_ells)
ee_unlensed = CubicSpline(lensing_ells, ee_raw, check=False)(self.lensing_ells)
def get_lensed_Cls():
tt_lensed, te_lensed, ee_lensed = self.lensed_Cls(self.lensing_ells, tt_unlensed, te_unlensed, ee_unlensed, PT, BG, params)
return (tt_lensed[self.ells-2], te_lensed[self.ells-2], ee_lensed[self.ells-2])
def get_unlensed_Cls():
return (tt_unlensed[self.ells-2], te_unlensed[self.ells-2], ee_unlensed[self.ells-2])
return lax.cond(
self.lensing,
get_lensed_Cls,
get_unlensed_Cls
)
[docs]
def Cl_one_ell(self, idx, PT, BG, params):
"""
Computes angular power spectrum for single multipole.
Integrates transfer functions over wavenumber.
Parameters:
-----------
idx : int
Index into bessel_l_tab for multipole ℓ
PT : perturbations.PerturbationTable
Perturbation evolution table
BG : background.Background
Background cosmology module
params : dict
Dictionary of input and derived parameters
Returns:
--------
tuple
(C_ℓ^TT, C_ℓ^TE, C_ℓ^EE) angular power spectra
"""
l = bessel_l_tab[idx]
k_axis = self.k_axis_transfer
lna_axis = PT.lna[:-1]
delta_lna = PT.lna[-1] - PT.lna[-2]
### TRANSFER FUNCTION ###
# Background quantities, all Nlna 1D vectors
tau0 = BG.tau0
tau = BG.tau(lna_axis)
g = vmap(BG.visibility,in_axes=[0,None])(lna_axis, params)
g_prime = vmap(grad(BG.visibility,argnums=0),in_axes=[0,None])(lna_axis, params) # Derivative of g w.r.t. lna
aH = BG.aH(lna_axis, params)
expmkappa = vmap(BG.expmkappa)(lna_axis)
aH_dot = BG.aH_prime(lna_axis, params) * aH # Derivative of aH w.r.t. conformal time tau.
# Keep a 1D alias of aH for the rolling-accumulator scan below.
aH_1d = aH
g = g[:, None]
g_prime = g_prime[:, None]
aH = aH[:, None]
expmkappa = expmkappa[:, None]
aH_dot = aH_dot[:, None]
# Perturbations, all (Nlna, Nk) 2D vectors
# Cubic Spline is necessary here for accuracy.
interp_column = lambda col : CubicSpline(jnp.log10(PT.k), col, check=False)(jnp.log10(k_axis))
# Found that this is much much faster than RegularGridInterpolator
photon_sp = PT.species_perturbations["Photon"]
baryon_sp = PT.species_perturbations["Baryon"]
delta_g = vmap(interp_column, in_axes=0, out_axes=0)(photon_sp["delta"][:-1, :])
theta_b = vmap(interp_column, in_axes=0, out_axes=0)(baryon_sp["theta"][:-1, :])
theta_b_prime = vmap(interp_column, in_axes=0, out_axes=0)(PT.theta_b_prime[:-1, :])
sigma_g = vmap(interp_column, in_axes=0, out_axes=0)(photon_sp["sigma"][:-1, :])
Gg0 = vmap(interp_column, in_axes=0, out_axes=0)(photon_sp["G0"][:-1, :])
Gg2 = vmap(interp_column, in_axes=0, out_axes=0)(photon_sp["G2"][:-1, :])
eta = vmap(interp_column, in_axes=0, out_axes=0)(PT.metric_eta[:-1, :])
eta_prime = vmap(interp_column, in_axes=0, out_axes=0)(PT.metric_eta_prime[:-1, :])
alpha = vmap(interp_column, in_axes=0, out_axes=0)(PT.metric_alpha[:-1, :])
alpha_prime = vmap(interp_column, in_axes=0, out_axes=0)(PT.metric_alpha_prime[:-1, :])
# Source terms
sourceT0 = self.scale_sw * g * (delta_g/4. + aH*alpha_prime) \
+ self.scale_isw * (
g * (eta - aH*alpha_prime - 2.*aH*alpha) \
+ 2.*expmkappa * (aH*eta_prime - aH_dot*alpha - aH**2*alpha_prime)
) \
+ self.scale_dop * (
aH * (g*((theta_b_prime / k_axis**2) + alpha_prime) \
+ g_prime*((theta_b / k_axis**2) + alpha))
)
sourceT1 = self.scale_isw * expmkappa * \
((aH*alpha_prime + 2.*aH*alpha - eta) * k_axis)
sourceT2 = self.scale_pol * g * (2*sigma_g + Gg0 + Gg2) / 8.
sourceE = jnp.sqrt(6) * g * (2*sigma_g + Gg0 + Gg2) / 8.
# Here we perform the time integral to get transfer functions from source functions.
# previously, this block explicitly built a 2D (Nlna, Nk) tensor for each ell and summed it down to (Nk).
# This newer version refactors into four accumulators of shape (Nk). For each lna, we compute all four
# (Nk), multiply by a trapezoid weight, and then add to the accumulator. The result is identical but
# avoids having to construct a full 2D tensor for each ell, instead just constructing the 1D (Nk) tensor
# and accumulating down ell. Clever "traingle term" added by hand is now handled by the trapezoid weights.
# Pre-slice bessel-table columns so the scan body doesn't re-index
# ..._tab[:, idx] every iteration.
x0_min = xphi0_tab[0, idx]
x0_max = xphi0_tab[-1, idx]
x1_min = xphi1_tab[0, idx]
x1_max = xphi1_tab[-1, idx]
x2_min = xphi2_tab[0, idx]
x2_max = xphi2_tab[-1, idx]
col_phi0_l = phi0_tab[:, idx]
col_phi1_l = phi1_tab[:, idx]
col_phi2_l = phi2_tab[:, idx]
ell_eps_factor = jnp.sqrt(3./8.*(l+2)*(l+1)*l*(l-1))
def phi0_local(x):
x_safe = jnp.where(x >= x0_max, x, x0_max)
return jnp.where(
x < x0_min,
0.,
jnp.where(
x >= x0_max,
j(l, x_safe),
tools.fast_interp(x, x0_min, x0_max, col_phi0_l)
)
)
def phi1_local(x):
x_safe = jnp.where(x >= x1_max, x, x1_max)
return jnp.where(
x < x1_min,
0.,
jnp.where(
x >= x1_max,
l/x_safe*j(l, x_safe) - j(l+1, x_safe),
tools.fast_interp(x, x1_min, x1_max, col_phi1_l)
)
)
def phi2_local(x):
x_safe = jnp.where(x >= x2_max, x, x2_max)
return jnp.where(
x < x2_min,
0.,
jnp.where(
x >= x2_max,
((3*l*(l-1)-2*x_safe**2)*j(l, x_safe)+6*x_safe*j(l+1, x_safe))/2/x_safe**2,
tools.fast_interp(x, x2_min, x2_max, col_phi2_l)
)
)
Nlna = lna_axis.shape[0]
weights = jnp.full((Nlna,), delta_lna, dtype=sourceT0.dtype)
weights = weights.at[0].set(0.5 * delta_lna)
zero_k = jnp.zeros(k_axis.shape, dtype=sourceT0.dtype)
def scan_step(carry, xs_l):
acc_T0, acc_T1, acc_T2, acc_E = carry
sT0_l, sT1_l, sT2_l, sE_l, aH_l, tau_l, w_l = xs_l
chi_l = (tau0 - tau_l) * k_axis
phi0_l = phi0_local(chi_l)
phi1_l = phi1_local(chi_l)
phi2_l = phi2_local(chi_l)
eps_l = phi0_l / chi_l**2 * ell_eps_factor
inv_aH = 1.0 / aH_l
acc_T0 = acc_T0 + w_l * sT0_l * inv_aH * phi0_l
acc_T1 = acc_T1 + w_l * sT1_l * inv_aH * phi1_l
acc_T2 = acc_T2 + w_l * sT2_l * inv_aH * phi2_l
acc_E = acc_E + w_l * sE_l * inv_aH * eps_l
return (acc_T0, acc_T1, acc_T2, acc_E), None
init = (zero_k, zero_k, zero_k, zero_k)
xs = (sourceT0, sourceT1, sourceT2, sourceE, aH_1d, tau, weights)
# jax.checkpoint on the scan body: during reverse AD, body intermediates
# are not saved — the body is re-executed on the backward pass. Kills
# the ~21 GiB (Nell, Nlna, Nk) integrand rematerialisation; adds ~2× on
# this scan's compute, a small fraction of SS wall time.
(transferT0, transferT1, transferT2, transferE), _ = lax.scan(
jax.checkpoint(scan_step), init, xs
)
transferT = transferT0 + transferT1 + transferT2
### END OF TRANSFER FUNCTION ###
# Now we integrate the transfer functions along the line of sight, and return.
integrandTT = 4.*jnp.pi * params['A_s'] * (k_axis/self.k_pivot)**(params['n_s']-1.) * transferT**2 / k_axis
integrandTE = 4.*jnp.pi * params['A_s'] * (k_axis/self.k_pivot)**(params['n_s']-1.) * transferT*transferE / k_axis
integrandEE = 4.*jnp.pi * params['A_s'] * (k_axis/self.k_pivot)**(params['n_s']-1.) * transferE**2 / k_axis
return (
jnp.trapezoid(integrandTT, k_axis),
jnp.trapezoid(integrandTE, k_axis),
jnp.trapezoid(integrandEE, k_axis)
)