#import abc
from jax import config, lax, vmap
import jax.numpy as jnp
import equinox as eqx
from . import constants as cnst
# Process-wide JAX setting: enable 64-bit floats/ints. It runs at import time,
# before the class-level jnp arrays in this module are created, so those
# arrays (e.g. the neutrino quadrature nodes) are built in double precision.
config.update("jax_enable_x64", True)
### ABSTRACT BASE CLASSES AND INTERFACES ###
class Fluid(eqx.Module):
    """
    Base class for fluid species.

    Defines the background (rho, P, w) and perturbation (y_ini, y_prime,
    rho_delta, ...) interface. Concrete species override the methods that
    raise NotImplementedError here.

    Fields:
    -------
    first_idx : int
        Default = 0
        Position of the first perturbation equation in the Diffrax vector. For
        most fluids this is the density perturbation mode "delta".
    num_equations : int
        Default = 0
        Number of equations that need to be simultaneously evolved in the
        perturbations module.
    name : str
        Default = ""
        Name of the fluid. Used as the key in species_dict, which maps species
        names to their position in species_list.
    is_matter : bool
        Default = False
        Whether the fluid is non-relativistic today and contributes towards the
        total matter power spectrum.

    Methods:
    --------
    rho : Compute energy density (units: eV cm^{-3})
    P : Compute pressure (units: eV cm^{-3})
    w : Compute equation of state parameter (units: dimensionless)
    y_ini : Adiabatic initial conditions, in synchronous gauge
    y_prime : Perturbation derivatives, in synchronous gauge
    rho_delta : Perturbed density δρ (units: eV cm^{-3})
    rho_plus_P_theta : Velocity perturbation (ρ+P)θ (units: eV cm^{-3} Mpc^{-1})
    rho_plus_P_sigma : Shear perturbation (ρ+P)σ (units: eV cm^{-3})
    output_perturbations : Named perturbation arrays for storage (empty by default)
    """
    first_idx : int = eqx.field(default=0, static=True)
    num_equations : int = eqx.field(default=0, static=True)
    name : str = eqx.field(default="", static=True)
    is_matter : bool = eqx.field(default=False, static=True)  # Does the fluid contribute towards matter overdensity today.

    def __init__(self, first_idx, specs):
        """
        Parameters:
        -----------
        first_idx : int
            Index of this species' first equation in the perturbation vector.
        specs : dict
            Run specifications. Unused by the base class; some subclasses read
            their hierarchy sizes (e.g. "l_max_g") from it.
        """
        self.first_idx = first_idx
        self.name = ""
        self.is_matter = False

    def rho(self, lna, args):
        """
        Compute energy density.

        Calculates the energy density of the fluid species at a given
        cosmological epoch using the logarithm of the scale factor.

        Parameters:
        -----------
        lna : float
            Logarithm of scale factor
        args : dict
            Cosmological parameters (params)

        Returns:
        --------
        float
            Energy density (units: eV cm^{-3})
        """
        raise NotImplementedError("Fluid species must implement an energy density function.")

    def P(self, lna, args):
        """
        Compute pressure.

        Calculates the pressure of the fluid species at a given
        cosmological epoch using the logarithm of the scale factor.

        Parameters:
        -----------
        lna : float
            Logarithm of scale factor
        args : dict
            Cosmological parameters (params)

        Returns:
        --------
        float
            Pressure (units: eV cm^{-3})
        """
        raise NotImplementedError("Fluid species must implement a pressure function.")

    def w(self, lna, args):
        """
        Compute equation of state parameter.

        Returns the ratio P / rho of this species.

        Parameters:
        -----------
        lna : float
            Logarithm of scale factor
        args : dict
            Cosmological parameters (params)

        Returns:
        --------
        float
            Equation of state parameter (units: dimensionless)
        """
        return self.P(lna, args)/self.rho(lna, args)

    def y_ini(self, k, tau_ini, args):
        """
        Compute initial conditions for perturbation modes.

        Parameters:
        -----------
        k : float
            Wavenumber (units: Mpc^{-1})
        tau_ini : float
            Initial conformal time (units: Mpc)
        args : dict
            Cosmological parameters (params)

        Returns:
        --------
        array, shape (num_equations,)
            Initial perturbation mode values
        """
        raise NotImplementedError("Fluid species must implement the initial conditions of their perturbation modes.")

    def y_prime(self, k, lna, metric_h_prime, metric_eta_prime, y, args):
        """
        Compute derivatives of this species' perturbation modes.

        Parameters:
        -----------
        k : float
            Wavenumber (units: Mpc^{-1})
        lna : float
            Logarithm of scale factor
        metric_h_prime : float
            Derivative of metric h
        metric_eta_prime : float
            Derivative of metric eta
        y : array
            Full perturbation state vector (all species)
        args : tuple
            (BG, params, species_list, species_dict): background cosmology,
            cosmological parameters, the list of species, and the map from
            species name to index in species_list.

        Returns:
        --------
        array, shape (num_equations,)
            Derivatives of this species' perturbation modes
        """
        raise NotImplementedError("Fluid species must implement a perturbation derivative function.")

    def rho_delta(self, lna, y, args):
        """
        Compute density perturbation.

        Parameters:
        -----------
        lna : float
            Logarithm of scale factor
        y : array
            Perturbation mode values
        args : dict
            Cosmological parameters (params)

        Returns:
        --------
        float
            Density perturbation (units: eV cm^{-3})
        """
        raise NotImplementedError("Fluid species must implement a density perturbation function.")

    def rho_plus_P_theta(self, lna, y, args):
        """
        Compute velocity perturbation.

        Parameters:
        -----------
        lna : float
            Logarithm of scale factor
        y : array
            Perturbation mode values
        args : dict
            Cosmological parameters (params)

        Returns:
        --------
        float
            Velocity perturbation (units: eV cm^{-3} Mpc^{-1})
        """
        raise NotImplementedError("Fluid species must implement a velocity perturbation function.")

    def rho_plus_P_sigma(self, lna, y, args):
        """
        Compute shear perturbation.

        Parameters:
        -----------
        lna : float
            Logarithm of scale factor
        y : array
            Perturbation mode values
        args : dict
            Cosmological parameters (params)

        Returns:
        --------
        float
            Shear perturbation (units: eV cm^{-3})
        """
        raise NotImplementedError("Fluid species must implement a shear perturbation function.")

    def output_perturbations(self, lna, modes, args):
        """
        Return named perturbation arrays for storage in PerturbationTable.

        Each concrete species overrides this to select the physically
        meaningful subset of its modes. Species with no perturbations
        (e.g. dark energy) return an empty dict via this base implementation.

        Parameters:
        -----------
        lna : array, shape (Nlna,)
            Logarithm of scale factor grid
        modes : array, shape (Ny, Nlna, Nk)
            Full perturbation state, already transposed
        args : tuple
            (BG, params) — background cosmology and cosmological parameters

        Returns:
        --------
        dict
            {quantity_name: array(Nlna, Nk)}. Empty for background-only species.
        """
        return {}
class StandardFluid(Fluid):
    """
    Standard implementation of perturbation methods for fluid species.

    The species' modes are read from consecutive entries of the perturbation
    vector starting at ``first_idx``: delta, then theta (only if
    num_equations > 1), then sigma (only if num_equations > 2). Modes the
    species does not evolve are reported as zero.

    Methods:
    --------
    get_delta, get_theta, get_sigma : Read individual modes from the state vector
    rho_delta : Compute standard density perturbation (units: eV cm^{-3})
    rho_plus_P_theta : Compute standard velocity perturbation (units: eV cm^{-3} Mpc^{-1})
    rho_plus_P_sigma : Compute standard shear perturbation (units: eV cm^{-3})
    """
    def __init__(self, first_idx, specs):
        super().__init__(first_idx, specs)

    def get_delta(self, lna, y, args):
        """
        Getter method for density perturbation from perturbation equations vector.

        Parameters:
        -----------
        lna : float
            Logarithm of scale factor
        y : array
            Perturbation mode values
        args : dict
            Cosmological parameters (params)

        Returns:
        --------
        float
            Dimensionless density perturbation (units: None)
        """
        return y[self.first_idx]

    def get_theta(self, lna, y, args):
        """
        Getter method for velocity divergence perturbation from perturbation equations vector.

        Parameters:
        -----------
        lna : float
            Logarithm of scale factor
        y : array
            Perturbation mode values
        args : dict
            Cosmological parameters (params)

        Returns:
        --------
        float
            Velocity divergence perturbation (units: 1/Mpc); zero if the
            species evolves fewer than 2 equations.
        """
        # num_equations is a static Python int, so branch in Python. A
        # jnp.where would still evaluate y[first_idx+1], which for a
        # one-equation species reads a neighbouring species' slot (or indexes
        # past the end of y).
        if self.num_equations > 1:
            return y[self.first_idx+1]
        # Zero with the same dtype and trailing shape jnp.where used to give.
        return jnp.zeros_like(y, shape=jnp.shape(y)[1:])

    def get_sigma(self, lna, y, args):
        """
        Getter method for shear perturbation from perturbation equations vector.

        Parameters:
        -----------
        lna : float
            Logarithm of scale factor
        y : array
            Perturbation mode values
        args : dict
            Cosmological parameters (params)

        Returns:
        --------
        float
            Dimensionless shear perturbation (units: None); zero if the
            species evolves fewer than 3 equations.
        """
        # Static branch for the same reason as get_theta.
        if self.num_equations > 2:
            return y[self.first_idx+2]
        return jnp.zeros_like(y, shape=jnp.shape(y)[1:])

    # Called by diffrax, child classes should never override. Okay to implement here.
    def rho_delta(self, lna, y, args):
        """
        Compute energy density perturbation, contribution to metric perturbation evolution.

        Parameters:
        -----------
        lna : float
            Logarithm of scale factor
        y : array
            Perturbation mode values
        args : dict
            Cosmological parameters (params)

        Returns:
        --------
        float
            Energy density perturbation rho * delta (units: eV cm^{-3})
        """
        params = args
        return self.rho(lna, params) * self.get_delta(lna, y, args)

    def rho_plus_P_theta(self, lna, y, args):
        """
        Compute velocity perturbation times the sum of energy density and
        pressure; the {0, i} component of the perturbed stress-energy tensor.

        Parameters:
        -----------
        lna : float
            Logarithm of scale factor
        y : array
            Perturbation mode values
        args : dict
            Cosmological parameters (params)

        Returns:
        --------
        float
            Velocity perturbation (rho + P) * theta (units: eV cm^{-3} Mpc^{-1})
        """
        params = args
        return (self.rho(lna, params)+self.P(lna, params)) * self.get_theta(lna, y, args)

    def rho_plus_P_sigma(self, lna, y, args):
        """
        Compute shear stress perturbation, needed for CMB.

        Parameters:
        -----------
        lna : float
            Logarithm of scale factor
        y : array
            Perturbation mode values
        args : dict
            Cosmological parameters (params)

        Returns:
        --------
        float
            Shear stress perturbation (rho + P) * sigma (units: eV cm^{-3})
        """
        params = args
        return (self.rho(lna, params)+self.P(lna, params)) * self.get_sigma(lna, y, args)
class BackgroundFluid(Fluid):
    """
    Base class for species that only contribute at the background level.

    Such species evolve no perturbation equations: their initial-condition and
    derivative vectors are empty, and every perturbed stress-energy
    contribution they report is zero.
    """
    num_equations = 0

    def __init__(self, first_idx, specs):
        super().__init__(first_idx, specs)

    def y_ini(self, k, tau_ini, args):
        """
        Trivial initial condition vector for background.

        Returns an empty array: there are no modes to initialise.
        """
        no_modes = jnp.array([])
        return no_modes

    def y_prime(self, k, lna, metric_h_prime, metric_eta_prime, y, args):
        """
        Trivial derivative vector for background.

        Returns an empty array: there are no modes to evolve.
        """
        no_modes = jnp.array([])
        return no_modes

    def rho_delta(self, lna, y, args):
        """Density perturbation; identically zero for a background-only species."""
        return 0.

    def rho_plus_P_theta(self, lna, y, args):
        """Velocity perturbation; identically zero for a background-only species."""
        return 0.

    def rho_plus_P_sigma(self, lna, y, args):
        """Shear perturbation; identically zero for a background-only species."""
        return 0.
### BEGINNING OF CONCRETE CLASSES ###
class DarkEnergy(BackgroundFluid):
    """
    Dark energy fluid species implementation.

    A cosmological-constant-like component: its energy density does not depend
    on the scale factor and its pressure is the negative of that density.

    Required input parameters: None.
    Required derived parameters: params['omega_Lambda']

    Methods:
    --------
    rho : Compute dark energy density (units: eV cm^{-3})
    P : Compute dark energy pressure (units: eV cm^{-3})
    """
    name = "DarkEnergy"

    def __init__(self, first_idx, specs):
        super().__init__(first_idx, specs)
        self.name = "DarkEnergy"

    def rho(self, lna, args):
        """
        Compute dark energy density.

        Parameters:
        -----------
        lna : float
            Logarithm of scale factor (unused: the density is constant)
        args : dict
            Cosmological parameters (params)

        Returns:
        --------
        float
            Dark energy density (units: eV cm^{-3})
        """
        params = args
        # 3 (H0/h)^2 / (8 pi G): the factor converting a physical density
        # parameter omega_X into an energy density.
        density_per_omega = 3.*cnst.H0_over_h**2/8./jnp.pi/cnst.G
        return params['omega_Lambda'] * density_per_omega

    def P(self, lna, args):
        """
        Compute dark energy pressure, P = -rho.

        Parameters:
        -----------
        lna : float
            Logarithm of scale factor
        args : dict
            Cosmological parameters (params)

        Returns:
        --------
        float
            Dark energy pressure (units: eV cm^{-3})
        """
        energy_density = self.rho(lna, args)
        return -energy_density
class ColdDarkMatter(StandardFluid):
    """
    Cold dark matter fluid species implementation.

    Non-relativistic, pressureless dark matter. Only the density perturbation
    is evolved (no velocity or shear modes).

    Required input parameters: params['omega_cdm'].
    Required derived parameters: params['om'].

    Methods:
    --------
    rho : Compute cold dark matter density (units: eV cm^{-3})
    P : Compute cold dark matter pressure (units: eV cm^{-3})
    y_ini : Compute initial perturbation conditions
    y_prime : Compute perturbation time derivatives
    """
    name = "ColdDarkMatter"
    num_equations = 1  # CDM only receives density perturbation in synchronous gauge.
    is_matter = True

    def __init__(self, first_idx, specs):
        super().__init__(first_idx, specs)
        self.name = "ColdDarkMatter"
        self.is_matter = True

    def rho(self, lna, args):
        """
        Compute cold dark matter density, scaling as a^{-3}.

        Parameters:
        -----------
        lna : float
            Logarithm of scale factor
        args : dict
            Cosmological parameters (params)

        Returns:
        --------
        float
            Cold dark matter density (units: eV cm^{-3})
        """
        params = args
        density_per_omega = 3.*cnst.H0_over_h**2/8./jnp.pi/cnst.G
        scale_factor = jnp.exp(lna)
        return params['omega_cdm'] * density_per_omega / scale_factor**3

    def P(self, lna, args):
        """
        Compute cold dark matter pressure.

        Cold dark matter is pressureless, so this always returns zero.

        Returns:
        --------
        float
            Cold dark matter pressure (units: eV cm^{-3})
        """
        return 0.

    def y_ini(self, k, tau_ini, args):
        """
        Compute initial conditions for cold dark matter perturbations.

        Parameters:
        -----------
        k : float
            Wavenumber (units: Mpc^{-1})
        tau_ini : float
            Initial conformal time (units: Mpc)
        args : dict
            Cosmological parameters (params)

        Returns:
        --------
        array, shape (1,)
            Initial density perturbation (units: dimensionless)
        """
        params = args
        om_correction = 1.-params["om"]*tau_ini/5.
        delta = -(k*tau_ini)**2/4. * om_correction
        return jnp.array([delta])

    def y_prime(self, k, lna, metric_h_prime, metric_eta_prime, y, args):
        """
        Compute the derivative of the cold dark matter density perturbation.

        Only metric_h_prime enters; the other arguments (including args) are
        unused.

        Returns:
        --------
        array, shape (1,)
            Derivative of density perturbation (units: dimensionless)
        """
        delta_prime = -0.5*metric_h_prime
        return jnp.array([delta_prime])

    def output_perturbations(self, lna, modes, args):
        """Store only the density contrast delta (shape (Nlna, Nk))."""
        delta = modes[self.first_idx]
        return {"delta": delta}
class MasslessNeutrino(StandardFluid):
    """
    Massless neutrinos fluid species implementation.

    Relativistic neutrinos evolved with a Boltzmann hierarchy of angular
    momentum modes: delta, theta, sigma (= F_2 / 2), then F_3 ... F_lmax.

    Required input parameters: params['N_nu_massless'], params['T_nu_massless'], params['TCMB0']
    Required derived parameters: params['R_nu'], params['om']
    Required specs: specs['l_max_massless_nu'] >= 4

    Methods:
    --------
    rho : Compute neutrino density (units: eV cm^{-3})
    P : Compute neutrino pressure (units: eV cm^{-3})
    y_ini : Compute initial perturbation conditions
    y_prime : Compute perturbation derivatives
    """
    name = "MasslessNeutrino"

    def __init__(self, first_idx, specs):
        super().__init__(first_idx, specs)
        l_max = specs["l_max_massless_nu"]
        # y_prime hard-codes the delta, theta, sigma and F_3 equations (F_3
        # couples to F_4) before the truncated F_lmax equation, so smaller
        # hierarchies would index past the block and return a wrong-length
        # derivative vector.
        if l_max < 4:
            raise ValueError(f"MasslessNeutrino requires l_max_massless_nu >= 4, got {l_max}.")
        self.name = "MasslessNeutrino"
        self.num_equations = l_max + 1

    def rho(self, lna, args):
        """
        Compute neutrino density.

        Parameters:
        -----------
        lna : float
            Logarithm of scale factor
        args : dict
            Cosmological parameters (params)

        Returns:
        --------
        float
            Neutrino density (units: eV cm^{-3})
        """
        params = args
        a = jnp.exp(lna)
        rho = params['N_nu_massless'] * 2. * 7./8. * jnp.pi**2/30. * params['T_nu_massless']**4 * params['TCMB0']**4 / a**4  # eV^4
        rho = rho / (cnst.c * cnst.hbar)**3  # Convert to eV cm^{-3}
        return rho

    def P(self, lna, args):
        """
        Compute neutrino pressure, P = rho / 3.

        Parameters:
        -----------
        lna : float
            Logarithm of scale factor
        args : dict
            Cosmological parameters (params)

        Returns:
        --------
        float
            Neutrino pressure (units: eV cm^{-3})
        """
        params = args
        return self.rho(lna, params)/3.

    def y_ini(self, k, tau_ini, args):
        """
        Compute initial conditions for massless neutrino perturbations.

        Parameters:
        -----------
        k : float
            Wavenumber (units: Mpc^{-1})
        tau_ini : float
            Initial conformal time (units: Mpc)
        args : dict
            Cosmological parameters (params)

        Returns:
        --------
        array, shape (num_equations,)
            Initial perturbation mode values (units: 1/Mpc for theta, else dimensionless)
        """
        params = args
        R_nu = params['R_nu']
        delta = - (k*tau_ini)**2/3. * (1.-params["om"]*tau_ini/5.)
        theta = - k*(k*tau_ini)**3/36./(4.*R_nu+15.) \
            * (4.*R_nu+11.+12.-3.*(8.*R_nu**2+50.*R_nu+275.)/20./(2.*R_nu+15.)*tau_ini*params["om"])
        sigma = (k*tau_ini)**2/(45.+12.*R_nu) * 2. * (1.+(4.*R_nu-5.)/4./(2.*R_nu+15.)*tau_ini*params["om"])
        # Only delta, theta and sigma are non-zero initially; F_3 and all
        # higher multipoles start at zero. The l = 2 slot holds sigma itself
        # (sigma = F_2 / 2), which is what y_prime evolves.
        return jnp.concatenate((jnp.array([delta, theta, sigma]), jnp.zeros(self.num_equations-3)))

    def y_prime(self, k, lna, metric_h_prime, metric_eta_prime, y, args):
        """
        Compute derivatives of massless neutrino perturbations.

        Parameters:
        -----------
        k : float
            Wavenumber (units: Mpc^{-1})
        lna : float
            Logarithm of scale factor
        metric_h_prime : float
            Derivative of metric h
        metric_eta_prime : float
            Derivative of metric eta
        y : array
            Full perturbation state vector
        args : tuple
            (BG, params, species_list, species_dict); only BG and params are used.

        Returns:
        --------
        array, shape (num_equations,)
            Perturbation derivatives (units: 1/Mpc for theta, else dimensionless)
        """
        BG, params, _, _ = args
        aH = BG.aH(lna, params)
        tau = BG.tau(lna)
        # first_idx and num_equations are static ints, so a plain slice
        # extracts this species' block without a gather.
        F = y[self.first_idx:self.first_idx + self.num_equations]
        delta = F[0]
        theta = F[1]
        sigma = F[2]  # sigma = F_2 / 2
        # density, velocity, shear perturbations
        delta_prime = -4./3./aH*theta - 2./3.*metric_h_prime
        theta_prime = k**2/aH*(delta/4.-sigma)
        sigma_prime = 4./15./aH*theta - 3./10.*k/aH*F[3] + 2./15.*metric_h_prime + 4./5.*metric_eta_prime
        F3_prime = 1./7. * k/aH * (6.*sigma - 4.*F[4])  # 3*F_2 = 6*sigma
        # Rest of the Boltzmann hierarchy, 4 <= l < lmax
        lmax = self.num_equations-1
        L = jnp.arange(4, lmax)
        Fl_prime = 1./(2.*L+1.)*k/aH * (L*F[L-1]-(L+1)*F[L+1])
        # Truncated equation for the last multipole
        Flmax_prime = k/aH*F[lmax-1] - (lmax+1)/aH/tau*F[lmax]
        return jnp.concatenate((jnp.array([delta_prime, theta_prime, sigma_prime, F3_prime]), Fl_prime, jnp.array([Flmax_prime])))

    def output_perturbations(self, lna, modes, args):
        """Store delta, theta and sigma, each of shape (Nlna, Nk)."""
        return {
            "delta": modes[self.first_idx],
            "theta": modes[self.first_idx + 1],
            "sigma": modes[self.first_idx + 2],
        }
class MassiveNeutrino(Fluid):
    """
    Massive neutrinos fluid species implementation.

    Neutrinos with mass, evolved with a separate Boltzmann hierarchy
    (Psi_0, k*Psi_1, Psi_2, ..., Psi_lmax) in each of three momentum bins.

    Required input parameters: params['N_nu_massive'], params['T_nu_massive'],
        params['m_nu_massive'], params['TCMB0']
    Required derived parameters: params['R_nu'], params['om']
    Required specs: specs['l_max_massive_nu'] >= 3

    Attributes:
    -----------
    num_ells_per_bin : int
        Number of multipole moments per momentum bin for massive neutrino hierarchy

    Methods:
    --------
    rho : Compute massive neutrino density (units: eV cm^{-3})
    P : Compute massive neutrino pressure (units: eV cm^{-3})
    y_ini : Compute initial perturbation conditions
    y_prime : Compute perturbation derivatives
    rho_delta : Compute density perturbation (units: eV cm^{-3})
    rho_plus_P_theta : Compute velocity perturbation (units: eV cm^{-3} Mpc^{-1})
    rho_plus_P_sigma : Compute shear perturbation (units: eV cm^{-3})
    """
    num_ells_per_bin : int = eqx.field(static=True)
    # Dimensionless momentum nodes q (energy epsilon = sqrt(q^2 + (m/T)^2)) and
    # quadrature weights: 3 nodes for the perturbation integrals, 5 nodes for
    # the background rho and P. The kernel each weight set assumes is applied
    # in the integrands below.
    q_3p = jnp.array([0.913201, 3.37517, 7.79184])
    w_3p = jnp.array([0.0687359, 3.31435, 2.29911])
    q_5p = jnp.array([0.583165, 2.0, 4.0, 7.26582, 13.0])
    w_5p = jnp.array([0.0081201, 0.689407, 2.8063, 2.05156, 0.12681])
    dlfdlq_3p = -q_3p / (1.+jnp.exp(-q_3p))  # Log derivative of fermi-dirac w.r.t. momentum
    name = "MassiveNeutrino"
    is_matter = True

    def __init__(self, first_idx, specs):
        super().__init__(first_idx, specs)
        l_max = specs["l_max_massive_nu"]
        # y_prime hard-codes Psi_0, k*Psi_1 and Psi_2 (Psi_2 couples to Psi_3)
        # before the truncated Psi_lmax equation, so l_max must be at least 3.
        if l_max < 3:
            raise ValueError(f"MassiveNeutrino requires l_max_massive_nu >= 3, got {l_max}.")
        self.name = "MassiveNeutrino"
        self.is_matter = True
        self.num_ells_per_bin = l_max + 1
        # One hierarchy per momentum bin (three bins, q_3p).
        self.num_equations = 3 * self.num_ells_per_bin

    def rho(self, lna, args):
        """
        Compute massive neutrino density.

        Parameters:
        -----------
        lna : float or ArrayLike (any shape)
            Logarithm of scale factor
        args : dict
            Cosmological parameters (params)

        Returns:
        --------
        float or ArrayLike, same shape as lna
            Massive neutrino density (units: eV cm^{-3})
        """
        params = args
        # Trailing length-1 axis broadcasts against the 5 quadrature nodes, so
        # any lna shape (including scalar) is supported.
        a = jnp.exp(lna)[..., None]
        T = params['T_nu_massive'] * params['TCMB0'] / a
        x = params['m_nu_massive'] / T
        integrand = (1. + jnp.exp(-self.q_5p)) / self.q_5p**2 \
            * jnp.sqrt(self.q_5p**2 + x**2)  # (..., 5)
        integral = jnp.dot(integrand, self.w_5p)  # contract the node axis -> lna.shape
        return params['N_nu_massive'] * 4. * T[..., 0]**4 / jnp.pi**2 * integral / cnst.hbar**3 / cnst.c**3

    def P(self, lna, args):
        """
        Compute massive neutrino pressure.

        Parameters:
        -----------
        lna : float or ArrayLike (any shape)
            Logarithm of scale factor
        args : dict
            Cosmological parameters (params)

        Returns:
        --------
        float or ArrayLike, same shape as lna
            Massive neutrino pressure (units: eV cm^{-3})
        """
        params = args
        a = jnp.exp(lna)[..., None]
        T = params['T_nu_massive'] * params['TCMB0'] / a
        x = params['m_nu_massive'] / T
        integrand = (1. + jnp.exp(-self.q_5p)) / jnp.sqrt(self.q_5p**2 + x**2)  # (..., 5)
        integral = jnp.dot(integrand, self.w_5p)
        return params['N_nu_massive'] * 4./3. * T[..., 0]**4 / jnp.pi**2 * integral / cnst.hbar**3 / cnst.c**3

    def y_ini(self, k, tau_ini, args):
        """
        Compute initial conditions for massive neutrino perturbations.

        Parameters:
        -----------
        k : float
            Wavenumber (units: Mpc^{-1})
        tau_ini : float
            Initial conformal time (units: Mpc)
        args : dict
            Cosmological parameters (params)

        Returns:
        --------
        array, shape (num_equations,)
            Initial perturbation mode values (units: 1/Mpc for kPsi1, else dimensionless)
        """
        params = args
        # Massless-neutrino initial conditions, mapped onto each momentum bin below.
        R_nu = params['R_nu']
        delta = - (k*tau_ini)**2/3. * (1.-params["om"]*tau_ini/5.)
        theta = - k*(k*tau_ini)**3/36./(4.*R_nu+15.) \
            * (4.*R_nu+11.+12.-3.*(8.*R_nu**2+50.*R_nu+275.)/20./(2.*R_nu+15.)*tau_ini*params["om"])
        sigma = (k*tau_ini)**2/(45.+12.*R_nu) * 2. * (1.+(4.*R_nu-5.)/4./(2.*R_nu+15.)*tau_ini*params["om"])
        bins = []
        for i in range(3):
            q = self.q_3p[i]
            # NOTE(ZZ): technically Psi1 carries a factor epsilon/q = 1/v; at
            # early times this is ~1 and is dropped here. Accuracy should be checked.
            first_three = jnp.array([delta/4., theta/3., sigma/2.]) * q / (1.+jnp.exp(-q))
            bins.append(jnp.concatenate((first_three, jnp.zeros(self.num_ells_per_bin - 3))))
        return jnp.concatenate(bins)

    def y_prime(self, k, lna, metric_h_prime, metric_eta_prime, y, args):
        """
        Compute derivatives of massive neutrino perturbations.

        Parameters:
        -----------
        k : float
            Wavenumber (units: Mpc^{-1})
        lna : float
            Logarithm of scale factor
        metric_h_prime : float
            Derivative of metric h
        metric_eta_prime : float
            Derivative of metric eta
        y : array
            Full perturbation state vector
        args : tuple
            (BG, params, species_list, species_dict); only BG and params are used.

        Returns:
        --------
        array, shape (num_equations,)
            Perturbation derivatives (units: 1/Mpc for kPsi1, else dimensionless)
        """
        BG, params, _, _ = args
        a = jnp.exp(lna)
        T = params['T_nu_massive'] * params['TCMB0'] / a
        x = params['m_nu_massive'] / T
        aH = BG.aH(lna, params)
        tau = BG.tau(lna)
        n_ell = self.num_ells_per_bin
        lmax = n_ell - 1
        L_inter = jnp.arange(3, lmax)  # Intermediate hierarchy, 3 <= L < lmax
        bins = []
        for i in range(3):
            q = self.q_3p[i]
            epsilon = jnp.sqrt(q**2 + x**2)
            dlnf0_dlnq = self.dlfdlq_3p[i]
            # NOTE: The entries are [Psi0, k * Psi1, Psi2, ...]. If accessing Psi1 make sure to divide out k
            start = self.first_idx + i*n_ell
            Psi = y[start:start + n_ell]
            Psi0_prime = -q/epsilon/aH*Psi[1] + metric_h_prime/6. * dlnf0_dlnq
            kPsi1_prime = q*k**2/3./epsilon/aH * (Psi[0] - 2.*Psi[2])
            Psi2_prime = q*k/5./epsilon/aH * (2.*Psi[1]/k - 3.*Psi[3]) - (metric_h_prime/15. + 2.*metric_eta_prime/5.) * dlnf0_dlnq
            Psi_inter_prime = q*k/epsilon/aH/(2*L_inter+1) * (L_inter*Psi[L_inter-1] - (L_inter+1)*Psi[L_inter+1])
            # Truncated equation for the last multipole
            Psi_lmax_prime = q*k/aH/epsilon*Psi[lmax-1] - (lmax+1)/aH/tau*Psi[lmax]
            bins.append(jnp.concatenate((jnp.array([Psi0_prime, kPsi1_prime, Psi2_prime]), Psi_inter_prime, jnp.array([Psi_lmax_prime]))))
        return jnp.concatenate(bins)

    def rho_delta(self, lna, y, args):
        """
        Compute massive neutrino density perturbation.

        Parameters:
        -----------
        lna : float
            Logarithm of scale factor
        y : array
            Perturbation mode values (leading axis indexes equations)
        args : dict
            Cosmological parameters (params)

        Returns:
        --------
        float
            Density perturbation (units: eV cm^{-3})
        """
        params = args
        a = jnp.exp(lna)
        T = params['T_nu_massive'] * params['TCMB0'] / a
        x = params['m_nu_massive'] / T
        res = 0.
        for i in range(3):
            q = self.q_3p[i]
            w = self.w_3p[i]
            epsilon = jnp.sqrt(q**2 + x**2)
            Psi0 = y[self.first_idx + i*self.num_ells_per_bin]
            res += w*(1.+jnp.exp(-q))*epsilon/q**2 * Psi0
        return params['N_nu_massive'] * res * 4./jnp.pi**2 * T**4 / cnst.hbar**3 / cnst.c**3

    def rho_plus_P_theta(self, lna, y, args):
        """
        Compute massive neutrino velocity perturbation (rho + P) * theta.

        Parameters:
        -----------
        lna : float
            Logarithm of scale factor
        y : array
            Perturbation mode values (leading axis indexes equations)
        args : dict
            Cosmological parameters (params)

        Returns:
        --------
        float
            Velocity perturbation (units: eV cm^{-3} Mpc^{-1})
        """
        params = args
        a = jnp.exp(lna)
        T = params['T_nu_massive'] * params['TCMB0'] / a
        res = 0.
        for i in range(3):
            q = self.q_3p[i]
            w = self.w_3p[i]
            kPsi1 = y[self.first_idx+1 + i*self.num_ells_per_bin]
            res += w*(1.+jnp.exp(-q))/q * kPsi1
        return params['N_nu_massive'] * res * 4./jnp.pi**2 * T**4 / cnst.hbar**3 / cnst.c**3

    def rho_plus_P_sigma(self, lna, y, args):
        """
        Compute massive neutrino shear perturbation (rho + P) * sigma.

        Parameters:
        -----------
        lna : float
            Logarithm of scale factor
        y : array
            Perturbation mode values (leading axis indexes equations)
        args : dict
            Cosmological parameters (params)

        Returns:
        --------
        float
            Shear perturbation (units: eV cm^{-3})
        """
        params = args
        a = jnp.exp(lna)
        T = params['T_nu_massive'] * params['TCMB0'] / a
        x = params['m_nu_massive'] / T
        res = 0.
        for i in range(3):
            q = self.q_3p[i]
            w = self.w_3p[i]
            epsilon = jnp.sqrt(q**2 + x**2)
            Psi2 = y[self.first_idx+2 + i*self.num_ells_per_bin]
            res += w*(1.+jnp.exp(-q))/epsilon * Psi2
        return params['N_nu_massive'] * res * 8./3./jnp.pi**2 * T**4 / cnst.hbar**3 / cnst.c**3

    def output_perturbations(self, lna, modes, args):
        """
        Return delta, theta and sigma, each of shape (Nlna, Nk), obtained by
        dividing the momentum-integrated perturbations by rho or rho + P.
        """
        _, params = args
        rho = self.rho(lna, params)  # (Nlna,)
        rhoP = rho + self.P(lna, params)
        rho_delta = vmap(self.rho_delta, in_axes=(0, 1, None))(lna, modes, params)  # (Nlna, Nk)
        rho_P_theta = vmap(self.rho_plus_P_theta, in_axes=(0, 1, None))(lna, modes, params)
        rho_P_sigma = vmap(self.rho_plus_P_sigma, in_axes=(0, 1, None))(lna, modes, params)
        return {
            "delta": rho_delta / rho[:, None],
            "theta": rho_P_theta / rhoP[:, None],
            "sigma": rho_P_sigma / rhoP[:, None],
        }
class Baryon(StandardFluid):
    """
    Baryon fluid species implementation.

    Non-relativistic baryons with density and velocity perturbations, coupled
    to photons through Thomson scattering.

    Required input parameters: params['omega_b'], params['YHe']
    Required derived parameters: params['R_nu'], params['R_b'], params['om']

    Methods:
    --------
    rho : Compute baryon density (units: eV cm^{-3})
    P : Compute baryon pressure (units: eV cm^{-3})
    cs2 : Compute sound speed squared (units: dimensionless)
    mean_mass : Compute mean baryon mass (units: eV)
    y_ini : Compute initial perturbation conditions
    y_prime : Compute perturbation derivatives
    """
    name = "Baryon"
    num_equations = 2
    is_matter = True

    def __init__(self, first_idx, specs):
        super().__init__(first_idx, specs)
        self.name = "Baryon"
        self.is_matter = True

    def rho(self, lna, args):
        """
        Compute baryon density, scaling as a^{-3}.

        Parameters:
        -----------
        lna : float
            Logarithm of scale factor
        args : dict
            Cosmological parameters (params)

        Returns:
        --------
        float
            Baryon density (units: eV cm^{-3})
        """
        params = args
        return params['omega_b'] * (3.*cnst.H0_over_h**2/8./jnp.pi/cnst.G) / jnp.exp(lna)**3

    def P(self, lna, args):
        """
        Compute baryon pressure.

        Baryon pressure is neglected, standard practice for SM baryons, so
        this always returns zero.

        Returns:
        --------
        float
            Baryon pressure (units: eV cm^{-3})
        """
        return 0.

    def _photon_and_R(self, lna, params, species_list, species_dict):
        """Return the Photon species and R = 4 rho_gamma / (3 rho_b) at lna."""
        photon = species_list[species_dict["Photon"]]
        R = 4.*photon.rho(lna, params)/3./self.rho(lna, params)
        return photon, R

    def cs2(self, lna, args):
        """
        Compute sound speed squared.

        Parameters:
        -----------
        lna : float
            Logarithm of scale factor
        args : tuple
            (BG, params, species_list, species_dict); species_list must contain
            a species registered under the name "Photon".

        Returns:
        --------
        float
            Sound speed squared (units: dimensionless)

        Notes:
        ------
        Adiabatic sound speed squared, taken from M&B Eq. (68).
        Although we can neglect the pressure, this term is important for perturbation growth
        during recombination. During reionization this cs2 is negative. This is not physical
        but it should not matter for cosmology.
        """
        BG, params, species_list, species_dict = args
        _, R = self._photon_and_R(lna, params, species_list, species_dict)
        Tm = BG.Tm(lna, params)  # Baryon temp
        Tg = BG.TCMB(lna, params)  # Photon temp
        mu = self.mean_mass(lna, (BG, params))
        return Tm/mu * (5./3. - 2./3.*mu*R/cnst.me/BG.aH(lna, params)/BG.tau_c(lna, params) * (Tg/Tm - 1.))

    def mean_mass(self, lna, args):
        """
        Compute mean baryon mass at given redshift.

        Parameters:
        -----------
        lna : float
            Logarithm of scale factor
        args : tuple
            Background cosmology and cosmological parameters (BG, params)

        Returns:
        --------
        float
            Mean baryon mass (units: eV)

        Notes:
        ------
        Defined to be mu = rho_b / n_b = rho_b / (nH + nHe + ne)
        """
        BG, params = args
        denom = (1.+BG.xe(lna))*(1.-params['YHe']) + cnst.mH / cnst.mHe * params['YHe']
        return cnst.mH / denom

    def y_ini(self, k, tau_ini, args):
        """
        Compute initial conditions for baryon perturbations.

        Parameters:
        -----------
        k : float
            Wavenumber (units: Mpc^{-1})
        tau_ini : float
            Initial conformal time (units: Mpc)
        args : dict
            Cosmological parameters (params)

        Returns:
        --------
        array, shape (2,)
            Initial [delta, theta] (units: 1/Mpc for theta, else dimensionless)
        """
        params = args
        delta = -(k*tau_ini)**2/4. * (1.-params["om"]*tau_ini/5.)
        theta = - k**4 * tau_ini**3/36. * (1.-3.*(1.+5.*params['R_b']-params['R_nu'])/20./(1.-params['R_nu'])*params["om"]*tau_ini)
        return jnp.array([delta, theta])

    def y_prime(self, k, lna, metric_h_prime, metric_eta_prime, y, args):
        """
        Compute derivatives of baryon perturbations.

        Parameters:
        -----------
        k : float
            Wavenumber (units: Mpc^{-1})
        lna : float
            Logarithm of scale factor
        metric_h_prime : float
            Derivative of metric h
        metric_eta_prime : float
            Derivative of metric eta (unused)
        y : array
            Full perturbation state vector (the photon theta is read from it)
        args : tuple
            (BG, params, species_list, species_dict); species_list must contain
            a species registered under the name "Photon".

        Returns:
        --------
        array, shape (2,)
            Derivatives of [delta, theta] (units: 1/Mpc for theta, else dimensionless)
        """
        BG, params, species_list, species_dict = args
        photon, R = self._photon_and_R(lna, params, species_list, species_dict)
        aH = BG.aH(lna, params)
        cs2 = self.cs2(lna, args)
        tau_c = BG.tau_c(lna, params)
        delta = y[self.first_idx]
        theta = y[self.first_idx+1]
        theta_g = photon.get_theta(lna, y, args)
        delta_prime = -theta/aH-metric_h_prime/2.
        # Hubble drag, pressure support, and Thomson momentum exchange with photons
        theta_prime = -theta + cs2*k**2*delta/aH + R/tau_c/aH*(theta_g-theta)
        return jnp.array([delta_prime, theta_prime])

    def output_perturbations(self, lna, modes, args):
        """Store delta and theta, each of shape (Nlna, Nk)."""
        return {
            "delta": modes[self.first_idx],
            "theta": modes[self.first_idx + 1],
        }
class Photon(StandardFluid):
    """
    Photon fluid species implementation.

    Relativistic photons with temperature (F) and polarization (G) Boltzmann
    hierarchies, coupled to baryons through Thomson scattering.

    Required input parameters: params['TCMB0']
    Required derived parameters: params['R_b'], params['R_nu'], params['om']
    Required specs: specs['l_max_g'] >= 4, specs['l_max_pol_g'] >= 3

    Attributes:
    -----------
    num_F_ell_modes : int
        Number of temperature multipole moments in Boltzmann hierarchy
    num_G_ell_modes : int
        Number of polarization multipole moments in Boltzmann hierarchy

    Methods:
    --------
    rho : Compute photon density (units: eV cm^{-3})
    P : Compute photon pressure (units: eV cm^{-3})
    y_ini : Compute initial perturbation conditions
    y_prime : Compute perturbation derivatives
    """
    num_F_ell_modes : int = eqx.field(static=True)
    num_G_ell_modes : int = eqx.field(static=True)
    name = "Photon"

    def __init__(self, first_idx, specs):
        super().__init__(first_idx, specs)
        l_max_g = specs["l_max_g"]
        l_max_pol_g = specs["l_max_pol_g"]
        # y_prime hard-codes delta, theta, sigma and F_3 (coupled to F_4) for
        # the temperature hierarchy, and the l = 0 and l = 2 sources for the
        # polarization hierarchy; smaller hierarchies would be mis-sized.
        if l_max_g < 4:
            raise ValueError(f"Photon requires l_max_g >= 4, got {l_max_g}.")
        if l_max_pol_g < 3:
            raise ValueError(f"Photon requires l_max_pol_g >= 3, got {l_max_pol_g}.")
        self.name = "Photon"
        self.num_F_ell_modes = l_max_g + 1
        self.num_G_ell_modes = l_max_pol_g + 1
        self.num_equations = self.num_F_ell_modes + self.num_G_ell_modes

    def rho(self, lna, args):
        """
        Compute photon density.

        Parameters:
        -----------
        lna : float
            Logarithm of scale factor
        args : dict
            Cosmological parameters (params)

        Returns:
        --------
        float
            Photon density (units: eV cm^{-3})
        """
        params = args
        a = jnp.exp(lna)
        return jnp.pi**2/15. * params["TCMB0"]**4 / a**4 / (cnst.c * cnst.hbar)**3

    def P(self, lna, args):
        """
        Compute photon pressure, P = rho / 3.

        Parameters:
        -----------
        lna : float
            Logarithm of scale factor
        args : dict
            Cosmological parameters (params)

        Returns:
        --------
        float
            Photon pressure (units: eV cm^{-3})
        """
        params = args
        return self.rho(lna, params)/3.

    def y_ini(self, k, tau_ini, args):
        """
        Compute initial conditions for photon perturbations.

        Parameters:
        -----------
        k : float
            Wavenumber (units: Mpc^{-1})
        tau_ini : float
            Initial conformal time (units: Mpc)
        args : dict
            Cosmological parameters (params)

        Returns:
        --------
        array, shape (num_equations,)
            Initial perturbation mode values (units: 1/Mpc for theta, else dimensionless).
            Only delta and theta are non-zero.
        """
        params = args
        delta = - (k*tau_ini)**2/3. * (1.-params["om"]*tau_ini/5.)
        theta = - k**4 * tau_ini**3/36. * (1.-3.*(1.+5.*params['R_b']-params['R_nu'])/20./(1.-params['R_nu'])*params["om"]*tau_ini)
        return jnp.concatenate((jnp.array([delta, theta]), jnp.zeros(self.num_equations - 2)))

    def y_prime(self, k, lna, metric_h_prime, metric_eta_prime, y, args):
        """
        Compute derivatives of photon perturbations.

        Parameters:
        -----------
        k : float
            Wavenumber (units: Mpc^{-1})
        lna : float
            Logarithm of scale factor
        metric_h_prime : float
            Derivative of metric h
        metric_eta_prime : float
            Derivative of metric eta
        y : array
            Full perturbation state vector (the baryon theta is read from it)
        args : tuple
            (BG, params, species_list, species_dict); species_list must contain
            a species registered under the name "Baryon".

        Returns:
        --------
        array, shape (num_equations,)
            Derivatives of [F_0 .. F_lmax, G_0 .. G_lmax]
            (units: 1/Mpc for theta, else dimensionless)
        """
        BG, params, species_list, species_dict = args
        # Get Baryon from list
        baryon = species_list[species_dict["Baryon"]]
        aH = BG.aH(lna, params)
        tau_c = BG.tau_c(lna, params)
        tau = BG.tau(lna)
        Flmax = self.num_F_ell_modes-1
        Glmax = self.num_G_ell_modes-1
        F = lax.dynamic_slice(y, (self.first_idx,), (self.num_F_ell_modes,))
        G = lax.dynamic_slice(y, (self.first_idx+self.num_F_ell_modes,), (self.num_G_ell_modes,))
        delta = F[0]
        theta = F[1]
        sigma = F[2]  # sigma = F_2 / 2
        theta_b = baryon.get_theta(lna, y, args)
        delta_prime = -4./3./aH*theta - 2./3.*metric_h_prime
        theta_prime = k**2/aH*(delta/4.-sigma) + (theta_b-theta)/aH/tau_c
        sigma_prime = 4./15./aH*theta - 3./10.*k/aH*F[3] + 2./15.*metric_h_prime + 4./5.*metric_eta_prime - 9./10./aH/tau_c*sigma + (G[0]+G[2])/20./aH/tau_c
        F3_prime = k/7./aH * (6.*sigma - 4.*F[4]) - F[3]/aH/tau_c
        # Temperature Boltzmann Hierarchy
        L = jnp.arange(4, Flmax)  # Excludes the lmax mode
        Fl_prime = 1./(2.*L+1.)*k/aH * (L*F[L-1]-(L+1)*F[L+1]) - F[L]/aH/tau_c
        Flmax_prime = k/aH*F[Flmax-1] - (Flmax+1)/aH/tau*F[Flmax] - F[Flmax]/aH/tau_c
        # Polarization Boltzmann Hierarchy. For L = 0 the term L*G[L-1] reads
        # G[-1] (the last entry) but is multiplied by zero. The source term
        # only feeds l = 0 (weight 1) and l = 2 (weight 1/5).
        L = jnp.arange(0, Glmax)  # Excludes the lmax mode
        Gl_prime = 1./(2.*L+1.)*k/aH * (L*G[L-1]-(L+1)*G[L+1]) - G[L]/aH/tau_c \
            + (2.*sigma+G[0]+G[2])/2./aH/tau_c * jnp.concatenate((jnp.array([1., 0., 0.2]), jnp.zeros(Glmax-3)))
        Glmax_prime = k/aH*G[Glmax-1] - (Glmax+1)/aH/tau*G[Glmax] - G[Glmax]/aH/tau_c
        return jnp.concatenate((jnp.array([delta_prime, theta_prime, sigma_prime, F3_prime]), Fl_prime, jnp.array([Flmax_prime]), Gl_prime, jnp.array([Glmax_prime])))

    def output_perturbations(self, lna, modes, args):
        """Store delta, theta, sigma and polarization G0, G2, each of shape (Nlna, Nk)."""
        return {
            "delta": modes[self.first_idx],
            "theta": modes[self.first_idx + 1],
            "sigma": modes[self.first_idx + 2],
            "G0": modes[self.first_idx + self.num_F_ell_modes],
            "G2": modes[self.first_idx + self.num_F_ell_modes + 2],
        }