# Source code for abcmb.background

import jax
from jax import config, vmap, lax
import numpy as np
import jax.numpy as jnp
import equinox as eqx
from diffrax import diffeqsolve, ODETerm, Kvaerno5, Tsit5, SaveAt, PIDController, ForwardMode
import optimistix as optx

from .hyrex.array_with_padding import array_with_padding
from .hyrex import recomb_functions
from .hyrex.hyrex import RecombInputs
from . import ABCMBTools as tools
from . import constants as cnst

import os
# Directory that contains this module file.
file_dir = os.path.dirname(__file__)
# Switch JAX to 64-bit (double precision) types for everything below.
config.update("jax_enable_x64", True)


class BackgroundPreRecomb(eqx.Module):
    """
    Pre-recombination background-cosmology object.

    Holds everything HyRex needs to run on CPU: the conformal-time
    tabulation, the species list, and the HyRex input arrays (bundled in a
    ``RecombInputs`` object).

    Attributes:
    -----------
    species_list : tuple
        A list of all fluids in the cosmology
    lna_tau_tab : jnp.array
        Log scale factor axis used to tabulate conformal time
    tau_tab : jnp.array
        Tabulated conformal time.
    tau0 : float
        Conformal time today in Mpc.
    recomb_inputs : RecombInputs
        Bundle of background quantities (TCMB, nH, H) sampled on
        ``RecModel.lna_axis_full``; consumed by HyRex.
    adjoint : diffrax.adjoint
        Adjoint mode for diffrax solves (static field).

    Methods:
    --------
    rho_tot : Compute total energy density (units: eV cm^{-3})
    P_tot : Compute total pressure (units: eV cm^{-3})
    H : Compute Hubble parameter (units: s^{-1})
    aH : Compute conformal Hubble parameter (units: Mpc^{-1})
    aH_prime : Compute derivative of conformal Hubble (units: Mpc^{-1})
    d2adtau2_over_a : Compute second derivative of scale factor (units: Mpc^{-2})
    tau : Compute conformal time (units: Mpc)
    nH : Compute hydrogen number density (units: cm^{-3})
    TCMB : Compute CMB temperature (units: eV)
    R_ratio_lna : Compute baryon drag ratio (units: dimensionless)
    """

    species_list : tuple
    # Axis for tabulating conformal time. Unannotated, so this is a plain
    # class attribute shared by all instances rather than a module field.
    lna_tau_tab = jnp.linspace(-33.0, 0.0, 10000)
    tau_tab : jnp.array            # Tabulated conformal time.
    tau0 : float                   # Conformal time today
    recomb_inputs : "RecombInputs"
    adjoint : "diffrax.adjoint" = eqx.field(static=True)

    def __init__(self, params, species_list, RecModel, adjoint=ForwardMode):
        """
        Initialize pre-recombination background.

        Tabulates conformal time and builds the RecombInputs object for HyRex.

        Parameters:
        -----------
        params : dict
            Cosmological parameters
        species_list : tuple
            List of fluid species for energy density calculations
        RecModel : hyrex.recomb_model
            Recombination module; only its ``lna_axis_full`` grid is read here
        adjoint : diffrax.adjoint, optional
            Adjoint class (not instance) for diffrax solves
            (default: ForwardMode)
        """
        # adjoint and species_list must be set first: the conformal-time
        # tabulation below reads both.
        self.adjoint = adjoint
        self.species_list = species_list

        self.tau_tab = self._tabulate_conformal_time(params)
        self.tau0 = self.tau(0.)

        # Bundle the background quantities HyRex needs onto its sampling
        # grid (according to the input RecModel)
        lna_axis = RecModel.lna_axis_full
        self.recomb_inputs = RecombInputs(
            lna_grid = lna_axis,
            TCMB_arr = vmap(self.TCMB, in_axes=[0, None])(lna_axis, params),
            nH_arr   = vmap(self.nH,   in_axes=[0, None])(lna_axis, params),
            H_arr    = vmap(self.H,    in_axes=[0, None])(lna_axis, params),
        )

    def rho_tot(self, lna, params):
        """
        Compute total energy density.

        Sums energy density over all species in the universe.

        Parameters:
        -----------
        lna : float
            Logarithm of scale factor
        params : dict
            Cosmological parameters

        Returns:
        --------
        float
            Total energy density (units: eV cm^{-3})
        """
        rho_tot = 0.
        for species in self.species_list:
            rho_tot += species.rho(lna, params)
        return rho_tot

    def P_tot(self, lna, params):
        """
        Compute total pressure.

        Sums pressure over all species in the universe.

        Parameters:
        -----------
        lna : float
            Logarithm of scale factor
        params : dict
            Cosmological parameters

        Returns:
        --------
        float
            Total pressure (units: eV cm^{-3})
        """
        P_tot = 0.
        for species in self.species_list:
            P_tot += species.P(lna, params)
        return P_tot

    def H(self, lna, params):
        """
        Compute Hubble parameter.

        Uses Einstein equation H = sqrt(8πG/3 ρ_tot) to account for novel
        species without well-defined density parameters.

        Parameters:
        -----------
        lna : float
            Logarithm of scale factor
        params : dict
            Cosmological parameters

        Returns:
        --------
        float
            Hubble parameter (units: s^{-1})
        """
        return jnp.sqrt(8.*jnp.pi*cnst.G*self.rho_tot(lna, params)/3.)

    def aH(self, lna, params):
        """
        Compute conformal Hubble parameter.

        Calculates the conformal Hubble rate aH = (1/a) da/dτ, where τ is
        conformal time. Uses Mpc units for perturbation calculations.

        Parameters:
        -----------
        lna : float
            Logarithm of scale factor
        params : dict
            Cosmological parameters

        Returns:
        --------
        float
            Conformal Hubble parameter (units: Mpc^{-1})
        """
        return jnp.exp(lna)*self.H(lna, params) / cnst.c_Mpc_over_s

    def aH_prime(self, lna, params):
        """
        Compute derivative of conformal Hubble parameter.

        Uses second Friedmann equation to compute d(aH)/d(ln a).
        See Eq.(20) of arXiv:astro-ph/9506072.

        Parameters:
        -----------
        lna : float
            Logarithm of scale factor
        params : dict
            Cosmological parameters

        Returns:
        --------
        float
            Derivative of conformal Hubble (units: Mpc^{-1})
        """
        return -4.*jnp.pi*cnst.G*jnp.exp(lna)**2/3./self.aH(lna, params) * (self.rho_tot(lna,params)+3.*self.P_tot(lna, params)) / cnst.c_Mpc_over_s**2

    def d2adtau2_over_a(self, lna, params):
        """
        Compute second derivative of scale factor.

        Calculates d²a/dτ²/a where τ is conformal time, as
        (aH)² + aH · d(aH)/d(ln a). Appears in perturbation evolution
        equations.

        Parameters:
        -----------
        lna : float
            Logarithm of scale factor
        params : dict
            Cosmological parameters

        Returns:
        --------
        float
            Second derivative of scale factor (units: Mpc^{-2})
        """
        return self.aH(lna, params)**2 + self.aH(lna, params)*self.aH_prime(lna, params)

    def _dtau_dlna(self, lna, y, args):
        """
        Compute derivative of conformal time with respect to ln(a).

        Has the ``(t, y, args)`` signature diffrax expects of an ODE term.

        Parameters:
        -----------
        lna : float
            Logarithm of scale factor
        y : float
            Current conformal time value (unused; the RHS does not depend on it)
        args : dict
            Cosmological parameters

        Returns:
        --------
        float
            Derivative dτ/d(ln a) (units: Mpc)
        """
        params = args
        return 1./self.aH(lna, params)

    def _tabulate_conformal_time(self, params):
        """
        Tabulate conformal time as function of ln(a).

        Integrates dτ/d(ln a) = 1/aH from early times to today using
        radiation-dominated initial conditions. We stitch an analytic
        early-time solution to a Diffrax dense interpolation, taking care
        not to evaluate out of bounds.

        Parameters:
        -----------
        params : dict
            Cosmological parameters

        Returns:
        --------
        array
            Conformal time on ``lna_tau_tab`` (units: Mpc)
        """
        lna_cut = -16.1  # use analytic approx before this

        # Analytic early-time approximation
        def tau_approx(lna):
            return (
                jnp.exp(lna)
                / (cnst.H0_over_h / cnst.c_Mpc_over_s)
                / jnp.sqrt(params["omega_r"])
            )

        lna_end = self.lna_tau_tab[-1]

        # ---- Diffrax solve (dense interpolation) ----
        term = ODETerm(self._dtau_dlna)
        controller = PIDController(rtol=1e-8, atol=1e-8)
        saveat = SaveAt(dense=True)
        adjoint = self.adjoint()  # the field stores the class; instantiate per solve
        sol = diffeqsolve(
            term,
            solver=Kvaerno5(),
            t0=lna_cut,
            t1=lna_end,
            dt0=1e-5,
            y0=tau_approx(lna_cut),
            saveat=saveat,
            stepsize_controller=controller,
            args=params,
            adjoint=adjoint,
        )

        # Numerical jitter causes this interpolation to go out of bounds on
        # some machines, so we do some extra work to safeguard that here:
        # Strictly inside [lna_cut, lna_end); avoid touching internal sol.ts (may be None).
        # nextafter gets the next representable float below lna_end to ensure in-bounds.
        lna_hi = jnp.nextafter(lna_end, -jnp.inf)

        def _tau_from_sol(lna):
            lna_in = jnp.clip(lna, lna_cut, lna_hi)
            return sol.evaluate(lna_in)

        def _tau_combined(lna):
            # Under the vmap below the predicate is batched, so lax.cond is
            # lowered to a select and BOTH branches run for every grid
            # point. The clip in _tau_from_sol is therefore what keeps the
            # dense interpolant in bounds for points that end up taking the
            # analytic branch.
            return lax.cond(lna > lna_cut, _tau_from_sol, tau_approx, lna)

        tau_tab = vmap(_tau_combined)(self.lna_tau_tab)

        # Replace any remaining non-finite entries with analytic fallback
        tau_tab = jnp.where(
            jnp.isfinite(tau_tab),
            tau_tab,
            vmap(tau_approx)(self.lna_tau_tab),
        )
        return tau_tab

    def tau(self, lna):
        """
        Compute conformal time.

        Interpolates from pre-tabulated conformal time history.
        Conformal time τ satisfies dτ = dt/a where t is cosmic time.

        Parameters:
        -----------
        lna : float
            Logarithm of scale factor

        Returns:
        --------
        float
            Conformal time (units: Mpc)

        Notes:
        ------
        The table is built once at construction: early times use the
        radiation-era approximation and later times a diffrax integration
        started from it (see ``_tabulate_conformal_time``).
        """
        return tools.fast_interp(lna, self.lna_tau_tab[0], self.lna_tau_tab[-1], self.tau_tab)

    def nH(self, lna, params):
        """
        Compute hydrogen number density.

        Calculates total hydrogen number density at given redshift from
        ``omega_b`` and ``YHe``, scaling as a^{-3}.

        Parameters:
        -----------
        lna : float
            Logarithm of scale factor
        params : dict
            Cosmological parameters

        Returns:
        --------
        float
            Hydrogen number density (units: cm^{-3})
        """
        return (1-params['YHe']) * 3. * params['omega_b'] * cnst.H0_over_h**2 / 8 / jnp.pi / cnst.G / cnst.mH / jnp.exp(lna)**3

    def TCMB(self, lna, params):
        """
        Compute CMB temperature.

        Calculates CMB temperature at given redshift using T ∝ 1/a scaling.

        Parameters:
        -----------
        lna : float
            Logarithm of scale factor
        params : dict
            Cosmological parameters

        Returns:
        --------
        float
            CMB temperature (units: eV)
        """
        return params['TCMB0'] / jnp.exp(lna)

    def R_ratio_lna(self, lna, params):
        """
        Compute baryon drag ratio.

        Calculates R = 3ρ_b/(4ρ_γ), the ratio of baryon to photon energy
        densities that appears in baryon drag calculations. Species are
        picked out of ``species_list`` by their ``name`` ("Photon",
        "Baryon").

        Parameters:
        -----------
        lna : float
            Logarithm of scale factor
        params : dict
            Cosmological parameters

        Returns:
        --------
        float
            Baryon drag ratio (units: dimensionless)
        """
        rho_b = 0.
        rho_g = 0.
        for s in self.species_list:
            if s.name == "Photon":
                rho_g += s.rho(lna, params)
            elif s.name == "Baryon":
                rho_b += s.rho(lna, params)
        return 3. * rho_b / (4 * rho_g)
class Background(BackgroundPreRecomb):
    """
    Full Background cosmology module for cosmological calculations.

    Inherits all cosmology fields and methods from ``BackgroundPreRecomb``.
    Construction takes a ``BackgroundPreRecomb`` and the recombination
    output from HyRex, then applies reionization and integrates the
    optical depth. This factorization allows HyRex to always run on CPU
    (its faster backend).

    Attributes:
    -----------
    species_list : tuple
        A list of all fluids in the cosmology
    lna_tau_tab : jnp.array
        Log scale factor axis used to tabulate conformal time
    tau_tab : jnp.array
        Tabulated conformal time.
    tau0 : float
        Conformal time today in Mpc.
    recomb_inputs : RecombInputs
        Bundle of background quantities (TCMB, nH, H) sampled on
        ``RecModel.lna_axis_full``; consumed by HyRex.
    adjoint : diffrax.adjoint
        Adjoint mode for diffrax solves (static field).
    xe_tab : array_with_padding
        Tabulated free electron fraction xe with reionization correction.
    lna_xe_tab : array_with_padding
        Log scale factor axis corresponding to tabulated xe values.
    Tm_tab : array_with_padding
        Tabulated matter temperature Tm during recombination.
    lna_Tm_tab : array_with_padding
        Log scale factor axis corresponding to tabulated Tm values.
    kappa_func : diffrax.solution
        Optical depth function (dense interpolation).
    z_reion : float
        Redshift of hydrogen reionization in the CAMB parameterization.
    tau_reion : float
        Optical depth to reionization.
    lna_rec : float
        Log scale factor of recombination.
    rA_rec : float
        Comoving angular diameter distance at recombination in Mpc.
    lna_transfer_start : float
        Log scale factor at which to begin integrating transfer functions.
    lna_visibility_stop : float
        Log scale factor at which to stop integrating T1, T2, and E sources
        due to small visibility functions. Only used for l<400.

    Methods:
    --------
    rho_tot : Compute total energy density (units: eV cm^{-3})
    P_tot : Compute total pressure (units: eV cm^{-3})
    H : Compute Hubble parameter (units: s^{-1})
    aH : Compute conformal Hubble parameter (units: Mpc^{-1})
    aH_prime : Compute derivative of conformal Hubble (units: Mpc^{-1})
    d2adtau2_over_a : Compute second derivative of scale factor (units: Mpc^{-2})
    tau : Compute conformal time (units: Mpc)
    xe : Compute free electron fraction (units: dimensionless)
    Tm : Compute matter temperature (units: eV)
    tau_c : Compute Thomson scattering time (units: Mpc)
    expmkappa : Compute exp(-kappa) (units: dimensionless)
    visibility : Compute visibility function (units: Mpc^{-1})
    z_d : Compute baryon decoupling redshift (units: dimensionless)
    rs_d : Compute sound horizon at decoupling (units: Mpc)
    """

    xe_tab : "array_with_padding"
    lna_xe_tab : "array_with_padding"
    Tm_tab : "array_with_padding"
    lna_Tm_tab : "array_with_padding"
    kappa_func : "diffrax.solution"
    z_reion : float
    tau_reion : float
    lna_rec : float
    rA_rec : float                 # Comoving angular diameter distance at recombination.

    # Transfer related
    lna_transfer_start : float     # Time where transfer functions start integrating.
    lna_visibility_stop : float    # Time to stop integrating T1, T2, and E sources due to small visibility functions. Only used for l<400

    def __init__(self, pre_BG, recomb_output, params, ReionModel):
        """
        Initialize Background cosmology module.

        Consolidates pre-recombination and recombination elements of
        background cosmology.

        Parameters:
        -----------
        pre_BG : BackgroundPreRecomb
            Output of the pre-recomb stage; provides species_list, tau_tab,
            tau0, recomb_inputs, adjoint.
        recomb_output : tuple
            HyRex output ``(xe, lna_xe, Tm, lna_Tm)`` quadruple
        params : dict
            Cosmological parameters.
        ReionModel : callable
            Reionization module for computing the xe correction; called as
            ``ReionModel(self, params)``.
        """
        # Copy pre-recomb fields onto self (the parent __init__ is
        # deliberately not re-run).
        self.adjoint = pre_BG.adjoint
        self.species_list = pre_BG.species_list
        self.tau_tab = pre_BG.tau_tab
        self.tau0 = pre_BG.tau0
        self.recomb_inputs = pre_BG.recomb_inputs

        # Unpack HyRex output and apply reionization.
        xe, self.lna_xe_tab, self.Tm_tab, self.lna_Tm_tab = recomb_output
        reion_model = ReionModel(self, params)
        self.z_reion = reion_model.z_reion
        self.tau_reion = reion_model.tau_reion
        xe_reion_correction = reion_model.xe_reion(self.lna_xe_tab.arr, self.z_reion, params)
        xe_full_arr = xe_reion_correction + xe.arr
        self.xe_tab = array_with_padding(xe_full_arr)

        # Replace inf padding in the recomb tabs with `lastval`. Forward
        # selects the same branch either way (the `where` in BG.xe/BG.Tm
        # gates the fast_interp dead branch out for lna in range). The inf
        # otherwise poisons the lensing=True reverse-AD cotangent: under
        # Kvaerno5+VeryChord, the IFT replay materializes the stage Jacobian
        # via vmap(jvp(RHS)) and chains a cotangent through the where's dead
        # branch into fast_interp past `lastnum`, giving 0×inf = NaN.
        def _finite_pad(awp):
            finite_arr = jnp.where(jnp.isinf(awp.arr), awp.lastval, awp.arr)
            return eqx.tree_at(lambda t: t.arr, awp, finite_arr)

        self.xe_tab = _finite_pad(self.xe_tab)
        self.lna_xe_tab = _finite_pad(self.lna_xe_tab)
        self.Tm_tab = _finite_pad(self.Tm_tab)
        self.lna_Tm_tab = _finite_pad(self.lna_Tm_tab)

        self.kappa_func = self._tabulate_optical_depth(params)

        # Find approximate maximum of visibility function.
        lna_vals = jnp.linspace(-8.0, -4.0, 1500)  # Decoupling falls in here.
        vis_vals = vmap(self.visibility, in_axes=[0, None])(lna_vals, params)
        self.lna_rec = lna_vals[jnp.argmax(vis_vals)]
        # Grid point where the visibility is closest to 1e-3.
        self.lna_visibility_stop = lna_vals[jnp.argmin((vis_vals - 1.e-3)**2)]
        self.rA_rec = self.tau0 - self.tau(self.lna_rec)

        # Find approximate early time when aH x tau_c = 0.008
        lna_vals = jnp.linspace(-15.0, -6.0, 5000)
        aH_tau_c_vals = vmap(self.aH, in_axes=[0, None])(lna_vals, params) * self.tau_c(lna_vals, params)
        self.lna_transfer_start = lna_vals[jnp.argmin((aH_tau_c_vals-0.008)**2)]

    ### RECOMBINATION RELATED ###

    def _late_or_tabulated(self, lna, lna_tab, val_tab):
        """
        Evaluate a padded recombination table at or after its first point.

        Returns ``val_tab.lastval`` once ``lna`` reaches ``lna_tab.lastval``
        and the ``tools.fast_interp`` value otherwise. Shared by ``xe`` and
        ``Tm``, which each add their own early-time branch on top.

        Parameters:
        -----------
        lna : float
            Logarithm of scale factor
        lna_tab : array_with_padding
            Log scale factor axis of the table
        val_tab : array_with_padding
            Tabulated values on that axis

        Returns:
        --------
        float
            Late-time constant or interpolated table value
        """
        lna_first = lna_tab.arr[0]
        # NOTE(review): the upper edge handed to fast_interp is
        # first + N*step over the padded length N, whereas tau() passes the
        # last grid point itself — confirm which convention
        # tools.fast_interp expects.
        lna_upper = lna_first + len(lna_tab.arr) * (lna_tab.arr[1]-lna_first)
        return jnp.where(
            lna >= lna_tab.lastval,
            val_tab.lastval,
            tools.fast_interp(lna, lna_first, lna_upper, val_tab.arr)
        )

    def xe(self, lna):
        """
        Compute free electron fraction.

        Interpolates from pre-tabulated recombination history with boundary
        conditions for early and late times.

        Parameters:
        -----------
        lna : float
            Logarithm of scale factor

        Returns:
        --------
        float
            Free electron fraction (units: dimensionless)

        Notes:
        ------
        The logic flow is equivalent to:

            if lna < self.lna_xe_tab.arr[0]:
                return self.xe_tab.arr[0]
            elif lna >= self.lna_xe_tab.lastval:
                return self.xe_tab.lastval
            else:
                return jnp.interp(lna, self.lna_xe_tab, self.xe_tab)
        """
        return jnp.where(
            lna < self.lna_xe_tab.arr[0],
            self.xe_tab.arr[0],
            self._late_or_tabulated(lna, self.lna_xe_tab, self.xe_tab)
        )

    def _Tm_early_approx(self, lna, params):
        """
        Compute matter temperature using post-equilibrium approximation.

        Uses approximation Tm = TCMB * (1 - H/GammaCompton) for early times
        before detailed recombination calculation begins.

        Parameters:
        -----------
        lna : float
            Logarithm of scale factor
        params : dict
            Cosmological parameters

        Returns:
        --------
        float
            Matter temperature (units: eV)
        """
        TCMB = self.TCMB(lna, params)
        xe = self.xe(lna)
        return TCMB * (1.-self.H(lna,params)/recomb_functions.Gamma_compton(xe, TCMB, params['YHe']))

    def Tm(self, lna, params):
        """
        Compute matter temperature.

        Interpolates from pre-tabulated recombination history with
        early-time approximation and late-time boundary conditions.

        Parameters:
        -----------
        lna : float
            Logarithm of scale factor
        params : dict
            Cosmological parameters

        Returns:
        --------
        float
            Matter temperature (units: eV)
        """
        return jnp.where(
            lna < self.lna_Tm_tab.arr[0],
            self._Tm_early_approx(lna, params),
            self._late_or_tabulated(lna, self.lna_Tm_tab, self.Tm_tab)
        )

    def tau_c(self, lna, params):
        """
        Compute Thomson scattering time.

        Calculates Thomson scattering time scale τc = 1/(a × ne × σT).

        Parameters:
        -----------
        lna : float
            Logarithm of scale factor
        params : dict
            Cosmological parameters

        Returns:
        --------
        float
            Thomson scattering time (units: Mpc)
        """
        a = jnp.exp(lna)
        nH = self.nH(lna, params)
        ne = nH * self.xe(lna)
        return 1./a/ne/cnst.thomson_xsec/cnst.c*cnst.c_Mpc_over_s

    def _tabulate_optical_depth(self, params):
        """
        Integrate the optical depth from today back to ln(a) = -10.

        Integrates dκ/d(ln a) = -1/(τc × aH) backwards from today
        (κ = 0 at ln a = 0) with a dense-output diffrax solve.

        Parameters:
        -----------
        params : dict
            Cosmological parameters

        Returns:
        --------
        diffrax solution
            Dense solution whose ``evaluate(lna)`` gives κ(ln a) on
            [-10, 0] (units: dimensionless)

        Notes:
        ------
        The integrand is the ln(a)-derivative of the optical depth, which
        involves the free electron fraction through τc. ``params`` reaches
        the integrand by closure; the solver's ``args`` is not used.
        """
        def dkappa_dlna(lna, y, args):
            return -1./self.tau_c(lna, params)/self.aH(lna, params)

        term = ODETerm(dkappa_dlna)
        stepsize_controller = PIDController(pcoeff=0.4, icoeff=0.3, dcoeff=0, rtol=1.e-10, atol=1.e-10)
        adjoint = self.adjoint()
        sol = diffeqsolve(
            term,
            solver=Kvaerno5(),
            stepsize_controller=stepsize_controller,
            t0=0.,
            t1=-10.,
            dt0=-1.e-3,
            max_steps=2048,
            y0=0.0,
            saveat=SaveAt(dense=True),
            adjoint=adjoint
        )
        return sol

    def expmkappa(self, lna):
        """
        Compute exp(-optical depth).

        Evaluates the dense optical-depth solution; returns 0 for
        ln(a) < -10, before the start of the tabulated range.

        Parameters:
        -----------
        lna : float
            Logarithm of scale factor

        Returns:
        --------
        float
            exp(-(optical depth)) (units: dimensionless)
        """
        too_early = lna < -10.
        # jnp.where evaluates both branches, so feed the interpolant an
        # in-range point whenever its result is going to be discarded. The
        # returned value is unchanged; this only keeps the masked branch
        # from being evaluated outside the solved interval.
        lna_safe = jnp.where(too_early, -10., lna)
        return jnp.where(
            too_early,
            0.,
            jnp.exp(-self.kappa_func.evaluate(lna_safe))
        )

    def visibility(self, lna, params):
        """
        Compute visibility function.

        Calculates visibility function g(x) = -aH(x) × κ'(x) × exp(-κ(x))
        where ' = d/dx and x = ln a; with κ' = -1/(τc × aH) this is
        exp(-κ)/τc. Represents probability that a CMB photon observed
        today was last scattered at time x.

        Parameters:
        -----------
        lna : float
            Logarithm of scale factor
        params : dict
            Cosmological parameters

        Returns:
        --------
        float
            Visibility function (units: Mpc^{-1})

        Notes:
        ------
        Used in computing source functions for CMB anisotropies.
        """
        return self.expmkappa(lna)/self.tau_c(lna, params)

    ###########################################
    ### tools for computing decoupling time ###
    ###########################################

    def find_z_at_kappad_equals_one(self, z, kappa_d):
        """
        Find redshift where baryon optical depth equals unity.

        Interpolates to find z_d such that |κ_d(z_d)| = 1, marking the
        approximate time of baryon decoupling.

        Parameters:
        -----------
        z : array
            Redshift array
        kappa_d : array
            Baryon optical depth array

        Returns:
        --------
        float
            Decoupling redshift (units: dimensionless)
        """
        # Sort by ascending redshift; the interpolation below then uses
        # |kappa_d| as the abscissa and z as the ordinate.
        idx = jnp.argsort(z)
        z_sorted = z[idx]
        kappa_d_sorted = jnp.abs(kappa_d)[idx]
        z_d = jnp.interp(1.0, kappa_d_sorted, z_sorted)
        return z_d

    def interp_rs_at_z(self, z_bg, r_s, z_d):
        """
        Interpolate sound horizon at decoupling redshift.

        Parameters:
        -----------
        z_bg : array
            Background redshift array
        r_s : array
            Sound horizon array
        z_d : float
            Decoupling redshift

        Returns:
        --------
        float
            Sound horizon at decoupling (units: Mpc)
        """
        idx = jnp.argsort(z_bg)
        z_sorted = z_bg[idx]
        rs_sorted = r_s[idx]
        return jnp.interp(z_d, z_sorted, rs_sorted)

    def _tabulate_kappa_d(self, params):
        """
        Tabulate baryon optical depth.

        Integrates dκ_d/d(ln a) = -1/(τc × aH × R) backwards from today
        to compute baryon optical depth including drag effects.

        Parameters:
        -----------
        params : dict
            Cosmological parameters

        Returns:
        --------
        array
            Baryon optical depth on ``lna_tau_tab``, in ascending ln(a)
            order (units: dimensionless)
        """
        def dkappa_d_dlna(lna, y, args):
            return jnp.float64(-1./self.tau_c(lna, params)/self.aH(lna, params)/(self.R_ratio_lna(lna, params)))

        term = ODETerm(dkappa_d_dlna)
        stepsize_controller = PIDController(pcoeff=0.4, icoeff=0.3, dcoeff=0, rtol=1.e-3, atol=1.e-6)
        adjoint = self.adjoint()
        solution = diffeqsolve(
            term,
            solver=Tsit5(),  # Kvaerno5 is just slower but gives same result
            stepsize_controller=stepsize_controller,
            t0=self.lna_tau_tab[-1],  # Initial x value (~0 in this case)
            t1=self.lna_tau_tab[0],   # Final x value (smallest x value)
            dt0=-1e-3,
            max_steps=2048,
            y0=0.0,  # Initial value kappa_d(x=0) = 0
            saveat=SaveAt(ts=self.lna_tau_tab[::-1]),  # Save at all points in x, reverse order since integrating backwards
            adjoint=adjoint
        )
        # Flip back so the result lines up with lna_tau_tab.
        result = solution.ys[::-1]
        return result

    def _tabulate_rs(self, params):
        """
        Tabulate sound horizon evolution.

        Integrates drs/d(ln a) = cs/aH from early times to today where
        cs = 1/√(3(1+R)) accounts for baryon loading.

        Parameters:
        -----------
        params : dict
            Cosmological parameters

        Returns:
        --------
        array
            Sound horizon on ``lna_tau_tab`` (units: Mpc)
        """
        # initial condition assuming cs**2 = 1/3 at early times
        rs0 = 1./jnp.sqrt(3) / (self.aH( self.lna_tau_tab[0], params ))

        def drs_dlna(lna, y, args):
            return 1./jnp.sqrt(3*(1+self.R_ratio_lna(lna, params))) / (self.aH(lna, params))

        term = ODETerm(drs_dlna)
        stepsize_controller = PIDController(pcoeff=0.4, icoeff=0.3, dcoeff=0, rtol=1.e-3, atol=1.e-6)
        adjoint = self.adjoint()
        solution = diffeqsolve(
            term,
            solver=Tsit5(),
            stepsize_controller=stepsize_controller,
            t0=self.lna_tau_tab[0],  # forward direction since rs is known at early times
            t1=self.lna_tau_tab[-1],
            dt0=1e-3,
            max_steps=2048,
            y0=rs0,
            saveat=SaveAt(ts=self.lna_tau_tab),
            adjoint=adjoint
        )
        result = solution.ys
        return result

    def z_d(self, params):
        """
        Compute baryon decoupling redshift.

        Finds redshift where κ_d = 1 as estimate of when baryons decouple
        from photons.

        Parameters:
        -----------
        params : dict
            Cosmological parameters

        Returns:
        --------
        float
            Decoupling redshift (units: dimensionless)
        """
        return self.find_z_at_kappad_equals_one(1/jnp.exp(self.lna_tau_tab) - 1, self._tabulate_kappa_d(params))

    def rs_d(self, params):
        """
        Compute sound horizon at decoupling.

        Finds value of sound horizon at baryon decoupling redshift z_d.

        Parameters:
        -----------
        params : dict
            Cosmological parameters

        Returns:
        --------
        float
            Sound horizon at decoupling (units: Mpc)
        """
        return self.interp_rs_at_z(1/jnp.exp(self.lna_tau_tab) - 1, self._tabulate_rs(params), self.z_d(params))
class ReionizationModel(eqx.Module):
    """
    Object for computing the reionization correction to the free
    electron fraction.

    Provides the base methods

    xe_reion : calculates the tanh electron fraction correction at
        redshifts lna, given z_reion and params
    tau_reion_fn : calculates the optical depth to reionization.

    At the moment we only support the CAMB tanh parameterization, but we
    need different approaches based on whether the user inputs the optical
    depth tau_reion or the reionization redshift z_reion.
    """

    z_reion : jnp.float64
    tau_reion : jnp.float64

    def xe_reion(self, lna, z_reion, params):
        """
        Compute the tanh reionization correction to xe.

        Parameters:
        -----------
        lna : float or array
            Logarithm of scale factor
        z_reion : float
            Reionization redshift used for the hydrogen/first-helium step
        params : dict
            Cosmological parameters; reads ``YHe``, ``exp_reion``,
            ``Delta_z_reion``, ``z_reion_He`` and ``Delta_z_reion_He``

        Returns:
        --------
        float or array
            Reionization contribution to xe (units: dimensionless)
        """
        YHe = params['YHe']
        exponent = params["exp_reion"]

        fHe = YHe / 4 / (1-YHe)
        z = 1/jnp.exp(lna) - 1

        # The tanh step is taken in y = (1+z)**exp_reion rather than in z.
        y = (1+z)**(exponent)
        y_reion = (1+z_reion)**(exponent)
        Delta_y_reion = exponent * (1+z_reion)**(exponent-1) * params["Delta_z_reion"]
        step_H = jnp.tanh((y_reion - y) / Delta_y_reion)
        # Hydrogen together with the first ionization level of helium.
        xe_reion_H = (1+fHe)/2 * (1 + step_H)

        # Second ionization of helium: a separate tanh step, directly in z.
        step_He = jnp.tanh((params["z_reion_He"] - z)/params["Delta_z_reion_He"])
        xe_reion_HeII = fHe/2 * (1 + step_He)

        return xe_reion_H + xe_reion_HeII

    def tau_reion_fn(self, z_reion, BG, params):
        """
        Compute the optical depth to reionization.

        Integrates the scattering rate of the reionization electrons over
        aH with the trapezoid rule on 2000 points in ln(a) from -5 to 0.

        Parameters:
        -----------
        z_reion : float
            Reionization redshift
        BG : background object
            Must provide ``nH(lna, params)`` and ``aH(lna, params)``
        params : dict
            Cosmological parameters

        Returns:
        --------
        float
            Optical depth to reionization (units: dimensionless)
        """
        lna_grid = jnp.linspace(-5., 0., 2000)

        # Free electron number density coming only from the reionization
        # correction.
        n_e = BG.nH(lna_grid, params) * self.xe_reion(lna_grid, z_reion, params)
        scatter_rate = jnp.exp(lna_grid)*n_e*cnst.thomson_xsec*cnst.c/cnst.c_Mpc_over_s
        conformal_hubble = BG.aH(lna_grid, params)

        # Optical depth integrand
        dtau_dlna = scatter_rate/conformal_hubble
        return jnp.trapezoid(dtau_dlna, lna_grid)
class ReionizationModelFromZ(ReionizationModel):
    """
    Concrete extension of the base ReionizationModel Class.

    This object is used when the user directly inputs the redshift of
    reionization. In this case the tanh correction and the optical depth
    can be computed directly, and simply returned.
    """

    def __init__(self, BG, params):
        """
        Read ``z_reion`` from params and compute the matching optical depth.

        Parameters:
        -----------
        BG : background object
            Passed through to ``tau_reion_fn``
        params : dict
            Cosmological parameters; ``z_reion`` falls back to 7.6711 when
            absent
        """
        z_reion = params.get("z_reion", jnp.array(7.6711))
        self.z_reion = z_reion
        self.tau_reion = self.tau_reion_fn(z_reion, BG, params)
class ReionizationModelFromTau(ReionizationModel):
    """
    Concrete extension of the base ReionizationModel Class.

    This object is used when the user inputs the optical depth and wishes
    to infer the redshift. The init will use an optimistix root finder to
    find the appropriate redshift. Then the appropriate tanh correction
    may be called and returned, as well as the inferred reionization
    redshift.
    """

    def __init__(self, BG, params):
        """
        Solve for the ``z_reion`` that reproduces the requested optical depth.

        Parameters:
        -----------
        BG : background object
            Passed through to ``tau_reion_fn``
        params : dict
            Cosmological parameters; ``tau_reion`` falls back to
            0.05430842 when absent
        """
        tau_target = params.get("tau_reion", jnp.array(0.05430842))

        def tau_residual(z_reion, target):
            # Zero when the model optical depth matches the target.
            return self.tau_reion_fn(z_reion, BG, params) - target

        newton = optx.Newton(rtol=1e-5, atol=1e-5)
        # 7.6 is the starting guess for the Newton iteration.
        root = optx.root_find(tau_residual, newton, 7.6, tau_target)

        self.z_reion = root.value
        self.tau_reion = tau_target