Source code for abcmb.ABCMBTools

"""
Script for helper numerical tools
"""
import jax
from jax import grad, lax, config, jit, vmap
from jax.scipy.special import gamma, factorial
from functools import partial
import numpy as np
import jax.numpy as jnp
import equinox as eqx

# Import-time side effect: switch JAX to 64-bit mode for the whole process.
# gauss_legendre_weights in this module explicitly builds jnp.float64 arrays.
config.update("jax_enable_x64", True)

### BEGINNING OF WIGNER ROTATION FOR LENSING ###

def wigner_d_matrix(mu, ells, m, n):
    """
    Compute reduced Wigner d-matrix elements d^ell_{mn}(beta) by upward recursion in ell.

    The closed-form value at ell = m seeds a three-term recurrence
    d^{ell+1} = (A*mu + B) * d^{ell} + C * d^{ell-1}, which is evaluated with
    lax.scan for every entry of ``ells`` and vectorised over ``mu`` with vmap.
    Internally the elements carry a sqrt((2*ell+1)/2) normalisation that is
    removed before returning.

    Parameters:
    -----------
    mu : array
        Cosine of rotation angle beta (1D; the function is vmapped over it)
    ells : array
        Consecutive multipole values [m, m+1, m+2, ..., ellmax]. The first
        entry must equal m because the base case is emitted at ells[0].
    m : int
        First index (must be non-negative and >= abs(n); m = 0 is used by d00)
    n : int
        Second index (must satisfy abs(n) <= m)

    Returns:
    --------
    array
        Wigner d-matrix elements, shape (len(mu), len(ells))
    """
    # Base case ell = m, in the normalised convention.
    def base_val(mu):
        beta = jnp.arccos(mu)
        norm = jnp.sqrt((2*m+1)/2) * jnp.sqrt(factorial(2*m)/(factorial(m+n)*factorial(m-n)))
        return norm * jnp.cos(beta/2.)**(m+n)*(-jnp.sin(beta/2.))**(m-n)

    # Recurrence coefficients, one per entry of ells (step ell -> ell + 1).
    normA = jnp.sqrt((2*ells+3)/(2*ells+1))
    normC = jnp.sqrt((2*ells+3)/(2*ells-1))
    denom = jnp.sqrt((ells+1)**2-m**2) * jnp.sqrt((ells+1)**2-n**2)

    # At ell = 0 (only reachable when m = n = 0) B is 0/0 and normC is
    # sqrt(-3); both terms multiply nothing that contributes, so the NaNs are
    # replaced by zero. The fill value is passed by keyword: the second
    # positional parameter of jnp.nan_to_num is ``copy``, not ``nan``.
    A = jnp.nan_to_num(normA * (ells+1)*(2*ells+1) / denom, nan=0.0)
    B = jnp.nan_to_num(-A * m * n / ells / (ells+1), nan=0.0)
    C = jnp.nan_to_num(-normC * jnp.sqrt(ells**2-m**2) * jnp.sqrt(ells**2-n**2) / denom * (ells+1)/ells, nan=0.0)

    def one_mu(mu):
        d_start = base_val(mu)  # Corresponds to ell = ells[0] = m

        def recursive_dlp1(carry, inputs):
            # carry holds (d^{ell}, d^{ell-1}); on the first step d^{ell-1} is 0.
            dl, dlm1 = carry
            a, b, c = inputs
            dlp1 = a*mu*dl + b*dl + c*dlm1
            # Emit d^{ell} so that output i lines up with ells[i], then shift
            # the window: d^{ell+1} -> d^{ell}, d^{ell} -> d^{ell-1}.
            return (dlp1, dl), dl

        # One scan step per entry of ells, starting from ell = m.
        (_, _), res = lax.scan(recursive_dlp1, (d_start, 0.), (A, B, C))
        # Strip the sqrt((2*ell+1)/2) normalisation used during the recursion.
        return res * jnp.sqrt(2./(2.*ells+1))

    return vmap(one_mu)(mu)
def d00(mu, ells):
    """
    Evaluate d^ell_{00} on the requested multipoles.

    The m = 0 recursion in wigner_d_matrix has to start at ell = 0, so
    ell = 0 and ell = 1 are prepended to ``ells`` for the computation and the
    two corresponding columns are dropped from the result.

    Parameters:
    -----------
    mu : array
        Cosine of rotation angle
    ells : array
        Multipole values (2, 3, 4, ..., ellmax)

    Returns:
    --------
    array
        d^ell_{00} elements for ells >= 2, one row per mu
    """
    leading = jnp.array([0, 1])
    full_ells = jnp.concatenate((leading, ells))
    # Columns 0 and 1 belong to the prepended ell = 0, 1.
    return wigner_d_matrix(mu, full_ells, 0, 0)[:, 2:]
def d1n(mu, ells, n):
    """
    Evaluate d^ell_{1n} on the requested multipoles.

    The m = 1 recursion in wigner_d_matrix starts at ell = 1, so ell = 1 is
    prepended to ``ells`` for the computation and its column is dropped from
    the result.

    Parameters:
    -----------
    mu : array
        Cosine of rotation angle
    ells : array
        Multipole values (presumably starting at ell = 2, since ell = 1 is
        prepended here)
    n : int
        Second index (abs(n) <= 1)

    Returns:
    --------
    array
        d^ell_{1n} elements, one row per mu and one column per entry of ells
    """
    full_ells = jnp.concatenate((jnp.array([1]), ells))
    # Column 0 belongs to the prepended ell = 1.
    return wigner_d_matrix(mu, full_ells, 1, n)[:, 1:]
def d2n(mu, ells, n):
    """
    Evaluate d^ell_{2n} on the requested multipoles.

    Thin wrapper around wigner_d_matrix with m = 2; ``ells`` is passed through
    unchanged, so it is expected to begin at ell = 2.

    Parameters:
    -----------
    mu : array
        Cosine of rotation angle
    ells : array
        Multipole values
    n : int
        Second index (abs(n) <= 2)

    Returns:
    --------
    array
        d^ell_{2n} elements
    """
    return wigner_d_matrix(mu, ells, 2, n)
def d3n(mu, ells, n):
    """
    Evaluate d^ell_{3n} on the requested multipoles.

    The first entry of ``ells`` is skipped so that the m = 3 recursion starts
    at ell = 3; a single zero column is put back in its place.

    Parameters:
    -----------
    mu : array
        Cosine of rotation angle
    ells : array
        Multipole values
    n : int
        Second index (abs(n) <= 3)

    Returns:
    --------
    array
        d^ell_{3n} elements, zero-padded for ell < 3
    """
    zero_block = jnp.zeros((mu.size, 1))  # stands in for ell < 3
    computed = wigner_d_matrix(mu, ells[1:], 3, n)  # recursion from ell = 3
    return jnp.concatenate((zero_block, computed), axis=1)
def d4n(mu, ells, n):
    """
    Evaluate d^ell_{4n} on the requested multipoles.

    The first two entries of ``ells`` are skipped so that the m = 4 recursion
    starts at ell = 4; two zero columns are put back in their place.

    Parameters:
    -----------
    mu : array
        Cosine of rotation angle
    ells : array
        Multipole values
    n : int
        Second index (abs(n) <= 4)

    Returns:
    --------
    array
        d^ell_{4n} elements, zero-padded for ell < 4
    """
    zero_block = jnp.zeros((mu.size, 2))  # stands in for ell < 4
    computed = wigner_d_matrix(mu, ells[2:], 4, n)  # recursion from ell = 4
    return jnp.concatenate((zero_block, computed), axis=1)
### END OF WIGNER ROTATION FOR LENSING ###

### LENSING INTEGRAL QUADRATURE METHODS ###

def _pn_and_pnm1_scan(z, n):
    """
    Evaluate the Legendre polynomials P_n(z) and P_{n-1}(z) element-wise.

    Runs the upward recurrence j*P_j = (2j - 1)*z*P_{j-1} - (j - 1)*P_{j-2}
    for j = 1..n inside lax.scan. Used by the Gauss-Legendre root finder to
    obtain both the polynomial and the ingredients of its derivative.
    """
    z = jnp.asarray(z)

    def advance(state, order):
        # state = (P_{order-1}, P_{order-2}); produce (P_order, P_{order-1}).
        current, previous = state
        following = ((2.0*order - 1.0) * z * current - (order - 1.0) * previous) / order
        return (following, current), None

    # Seed with P_0 = 1 and a dummy P_{-1} = 0 (its coefficient vanishes at j = 1).
    seed = (jnp.ones_like(z), jnp.zeros_like(z))
    (p_n, p_nm1), _ = lax.scan(advance, seed, jnp.arange(1, n+1))
    return p_n, p_nm1
def gauss_legendre_weights(n, tol=1.e-16, max_it=50):
    """
    Iteratively finds the roots and weights for Gauss-Legendre quadrature
    integration between -1 and 1, given the number of roots n requested.

    Only the m = (n + 1) // 2 non-negative roots are solved for with Newton's
    method; the remaining ones follow from the symmetry of P_n.

    Parameters:
    -----------
    n : int
        Number of roots desired, typically set by lmax of the lensed power
        spectrum. Must be a concrete Python integer >= 1 (it fixes array shapes).
    tol : jnp.float64
        Accuracy tolerance on the Newton root finder. Iteration stops once the
        largest root update is <= tol. NOTE: the default is below float64
        machine epsilon, so in practice the loop ends when the update is
        exactly zero or when max_it is reached.
    max_it : int
        Maximum iteration on the Newton root finder.

    Returns:
    --------
    (mu, w) : (jnp.array, jnp.array)
        The roots mu (ascending) and weights w, each of shape (n,).

    Raises:
    -------
    ValueError
        If n < 1.
    """
    if n < 1:
        # Without this guard an empty root array reaches jnp.max and fails
        # with an unrelated-looking zero-size reduction error.
        raise ValueError(f"gauss_legendre_weights requires n >= 1, got n={n}")

    dtype = jnp.float64
    m = (n + 1) // 2
    i = jnp.arange(1, m + 1, dtype=dtype)
    # Standard cosine initial guess for the i-th largest root.
    z0 = jnp.cos(jnp.array(jnp.pi, dtype=dtype) * (i - 0.25) / (n + 0.5))

    def legendre_and_derivative(z):
        # P_n(z) and P_n'(z) = n * (z*P_n - P_{n-1}) / (z^2 - 1)
        p_n, p_nm1 = _pn_and_pnm1_scan(z, n)
        pp = n * (z * p_n - p_nm1) / (z*z - 1.0)
        return p_n, pp

    def newton_step(z):
        p_n, pp = legendre_and_derivative(z)
        z_new = z - p_n / pp
        return z_new, jnp.max(jnp.abs(z_new - z))

    def cond(state):
        _, err, it = state
        return jnp.logical_and(err > tol, it < max_it)

    def body(state):
        z, _, it = state
        z_new, err_new = newton_step(z)
        return (z_new, err_new, it + 1)

    # One step outside the loop supplies the initial error estimate.
    z1, err1 = newton_step(z0)
    z, _, _ = lax.while_loop(cond, body, (z1, err1, jnp.array(1)))

    # Derivative at the converged roots, needed for the weights.
    _, pp = legendre_and_derivative(z)
    w_half = 2.0 / ((1.0 - z*z) * pp * pp)

    # Mirror the half-range roots: mu[i-1] = -z(i), mu[n-i] = z(i).
    # For odd n both assignments write the (zero) middle root consistently.
    mu = jnp.empty((n,), dtype=dtype)
    w = jnp.empty((n,), dtype=dtype)
    mu = mu.at[:m].set(-z)
    mu = mu.at[n-m:].set(z[::-1])
    w = w.at[:m].set(w_half)
    w = w.at[n-m:].set(w_half[::-1])
    return mu, w
def fast_interp(x, xp_min, xp_max, fp):
    """
    Fast 1D linear interpolation for uniformly-spaced grids.

    Optimized interpolation that avoids searchsorted by exploiting uniform
    grid spacing: the bracketing indices are computed arithmetically from the
    query position. Queries outside [xp_min, xp_max] are clamped to the
    nearest end value, and queries exactly at the ends return fp[0] / fp[-1]
    exactly.

    Parameters:
    -----------
    x : float or array
        Query points for interpolation
    xp_min : float
        Minimum value of interpolation grid
    xp_max : float
        Maximum value of interpolation grid
    fp : array
        Function values on uniform grid

    Returns:
    --------
    float or array
        Interpolated values at query points

    Notes:
    ------
    Credit: JAX issue #16182 (https://github.com/jax-ml/jax/issues/16182)
    Assumes fp is uniformly spaced between xp_min and xp_max.
    """
    # The official jnp.interp is very slow because it uses searchsorted.
    # Therefore, we leverage the fact that the fp grid is linearly increasing,
    # evenly spaced, and has a known range to make this operation much faster.
    n = fp.shape[-1]

    # Fractional grid index of each query; the (n - 1) factor fixes a bug in
    # the snippet from the JAX issue.
    i = (x - xp_min) / (xp_max - xp_min) * (n - 1)

    # Clamp onto the grid. No epsilon margin is needed: i_upper is clamped
    # separately below, so i == n - 1 maps to fp[n - 1] with zero upper weight.
    i = jnp.clip(i, 0.0, n - 1.0)
    i_lower = jnp.floor(i).astype(jnp.int32)
    i_upper = jnp.minimum(i_lower + 1, n - 1)

    w_upper = i - i_lower
    w_lower = 1.0 - w_upper
    return w_lower * fp[i_lower] + w_upper * fp[i_upper]
def bilinear_interp(x, y, z, xq, yq):
    """
    Bilinear interpolation on 2D regular grid.

    Locates the grid cell containing the query point (xq, yq) and blends the
    four corner values of that cell. Queries outside the grid use the nearest
    edge cell, so the result is a linear extrapolation there.

    Parameters:
    -----------
    x : array
        1D array of x-coordinates (must be sorted)
    y : array
        1D array of y-coordinates (must be sorted)
    z : array
        2D array of function values, shape (len(y), len(x))
    xq : float
        Query x-coordinate
    yq : float
        Query y-coordinate

    Returns:
    --------
    float
        Interpolated value at (xq, yq)
    """
    def lower_index(grid, q):
        # Index of the grid point at the low side of the cell containing q,
        # restricted so that index + 1 is always a valid grid point.
        return jnp.clip(jnp.searchsorted(grid, q) - 1, 0, grid.size - 2)

    col = lower_index(x, xq)
    row = lower_index(y, yq)

    # Fractional position of the query inside the cell along each axis.
    x_lo, x_hi = x[col], x[col + 1]
    y_lo, y_hi = y[row], y[row + 1]
    u = (xq - x_lo) / (x_hi - x_lo)
    v = (yq - y_lo) / (y_hi - y_lo)

    # Corner weights, named w<row offset><column offset>.
    w00 = (1 - u) * (1 - v)
    w01 = u * (1 - v)
    w10 = (1 - u) * v
    w11 = u * v

    return (
        w00 * z[row, col]
        + w01 * z[row, col + 1]
        + w10 * z[row + 1, col]
        + w11 * z[row + 1, col + 1]
    )