FAQ
What is this JAX thing, anyway? Do I need to be a JAX expert to use ABCMB?
You should not need to know very much about JAX in order to use ABCMB. JAX is a Python library that lets you write fast and differentiable Python code without having to depart too far from the ordinary Python you’re already familiar with. In general, the following tips are helpful to keep in mind:
- Always use jax.numpy as opposed to numpy. (At the top of your script, you can import jax.numpy as jnp and use jnp wherever you would ordinarily use np.) SciPy is also not generally safe to use with JAX. Use jax.scipy instead, or, if the function you need is missing from jax.scipy, look into community packages like interpax, diffrax, or quadax.
- You typically can't write conditionals in JAX the way you would in ordinary Python. Conditions on values that are fixed when your model is set up (e.g. if FLAG == True, where FLAG is set at initialization) are just fine. But if a function in your custom fluid needs a conditional on a numerical value that JAX traces (such as a float input), like if x > 5:, use jnp.where instead. The Python code:

      if x > 5:
          return x**2
      else:
          return -x**2

  can be rewritten as the more JAX-friendly:

      return jnp.where(x > 5, x**2, -x**2)
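A minimal sketch of the conditional tip in action, using plain JAX only (the function names branchy_python and branchy_jax are hypothetical and not part of ABCMB). Under jax.jit, the Python if fails because x is an abstract tracer, while the jnp.where version compiles, works elementwise on arrays, and can be differentiated:

```python
import jax
import jax.numpy as jnp

def branchy_python(x):
    # Ordinary Python control flow: fine for concrete numbers,
    # but fails under jit because x is then an abstract tracer.
    if x > 5:
        return x**2
    else:
        return -x**2

@jax.jit
def branchy_jax(x):
    # jnp.where evaluates both branches and selects elementwise,
    # so it works on traced values and on arrays.
    return jnp.where(x > 5, x**2, -x**2)

print(branchy_jax(3.0))                    # -9.0
print(branchy_jax(6.0))                    # 36.0
print(branchy_jax(jnp.array([3.0, 6.0])))  # [-9. 36.]
print(jax.grad(branchy_jax)(6.0))          # 12.0

try:
    jax.jit(branchy_python)(3.0)
    python_if_failed = False
except jax.errors.ConcretizationTypeError:
    # Newer JAX raises TracerBoolConversionError, a subclass of this error.
    python_if_failed = True
print(python_if_failed)  # True
```

Note that jnp.where computes both branches, so each branch must be safe to evaluate for every input.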
See the JAX documentation (in particular "JAX - The Sharp Bits") for more common gotchas if your custom modules throw errors or keep recompiling.
How do I take gradients of ABCMB output?
In general, it is best to use forward-mode automatic differentiation (jax.jacfwd) with ABCMB. Reverse-mode AD, such as jax.grad or jax.jacrev, has to keep the many internal states of the calculation around for the backward pass, which can quickly push memory requirements out of hand.
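To illustrate without relying on ABCMB's API, here is a hypothetical toy function whose output depends on 1000 internal solver-like steps. jax.jacfwd computes its Jacobian at a cost that scales with the number of input parameters and does not need to store the step-by-step history, whereas jax.jacrev must retain those intermediate states for its backward pass:

```python
import jax
import jax.numpy as jnp

def toy_observable(params):
    """Hypothetical stand-in for an ABCMB output (not ABCMB's API).

    Integrates dy/dt = -a*y + b with 1000 fixed Euler steps, so the
    output depends on many internal states, as in a real Boltzmann code.
    """
    a, b = params[0], params[1]
    dt = 1e-3

    def step(y, _):
        y_next = y + dt * (-a * y + b)
        return y_next, y_next

    _, ys = jax.lax.scan(step, jnp.ones_like(a), None, length=1000)
    return ys[::100]  # 10 "observables" sampled along the trajectory

params = jnp.array([2.0, 0.5])

# Forward mode: cost scales with the 2 input parameters, and no
# history of intermediate states is kept for a backward pass.
jac = jax.jacfwd(toy_observable)(params)
print(jac.shape)  # (10, 2)
```

jax.jacrev(toy_observable)(params) returns the same Jacobian here, but in reverse mode the stored intermediates grow with the number of internal steps, which is what becomes expensive in a full cosmology pipeline.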
Why am I seeing my code recompile?
There are a few reasons why otherwise JAX-safe code might not reuse the cached JIT-compiled version. Passing arguments of different data types to the same JIT-compiled function will trigger recompilation (e.g. passing in the integer 1 vs the float 1.).
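A minimal, self-contained sketch in plain JAX (the function square is hypothetical, not part of ABCMB): a Python counter incremented inside a jitted function only runs while JAX traces it, so it counts how many times the function is traced and compiled:

```python
import jax

trace_count = 0

@jax.jit
def square(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only when JAX (re)traces
    return x ** 2

square(1.0)  # first call with a float: traced and compiled
square(2.0)  # same dtype and shape: reuses the cached version
square(1)    # integer argument: traced and compiled again
print(trace_count)  # 2
```

Passing inputs with a consistent type (e.g. always floats) avoids the second trace.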
You may also be seeing recompilation because you wrapped Model.run_cosmology in a larger jit context. We do not recommend enclosing Model.run_cosmology in another jax.jit, for a couple of reasons:
- add_derived_parameters, the first auxiliary function to be called under the hood, is intended to be called outside of jit. In principle, this can be worked around by wrapping the inputs to your outermost jit context in jnp.array.
- LINX is CPU-optimized and has been carefully extracted from the rest of the ABCMB jit context so that it always runs on CPU, regardless of whether a GPU is present. Wrapping Model.run_cosmology in a larger jit context will slow down your code substantially if you are running with BBN. Future versions may also force CPU evaluation of HyRex in a similar fashion, so you will always take a performance hit if you choose to jit Model.run_cosmology.
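ABCMB's own mechanism for keeping LINX on the CPU is not described here. As a generic JAX illustration only (the function below is a hypothetical stand-in, not LINX code, and this is not a claim about how ABCMB does it), committing a function's inputs to the CPU device with jax.device_put makes JAX execute that computation on the CPU even when a GPU is the default backend:

```python
import jax
import jax.numpy as jnp

cpu = jax.devices("cpu")[0]  # the host CPU device, present even on GPU machines

@jax.jit
def cpu_friendly_step(x):
    # Hypothetical stand-in for a CPU-optimized module; not LINX's real code.
    return jnp.cumsum(jnp.exp(-x))

# device_put commits the input to the CPU, so the jitted call executes there.
x_cpu = jax.device_put(jnp.linspace(0.0, 1.0, 5), cpu)
y = cpu_friendly_step(x_cpu)
print(y.devices() == {cpu})  # True
```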
Finally, you may be seeing recompilation because you've encountered a bug! After you've ruled out the causes above, feel free to open an issue on our GitHub. If you'd like to explore the cause yourself, turn on JAX's cache-miss explanations with jax.config.update("jax_explain_cache_misses", True) before running your recompiling code.
Can I add new methods to my custom fluids beyond what ABCMB expects?
Yes! abcmb.species.Baryon is a good example of a fluid that has extra methods.
Help! My new fluid is breaking the differential equation solver!
There are a couple of initialization parameters you can adjust if your new cosmology is giving the solver a hard time. If diffrax reports that it reached its max_steps, you can raise this limit with the max_steps_PE initialization parameter of your Model. It defaults to 2048, but reasonable extensions to LCDM can require 4096 or sometimes even more.
If increasing max_steps_PE doesn't help, you can also try adjusting the solver's relative and absolute tolerances: rtol_small_k_PE and rtol_large_k_PE (relative tolerances at small and large k, respectively), and atol_small_k_PE and atol_large_k_PE (absolute tolerances at small and large k). Their defaults are 1e-5, 1e-4, 1e-10, and 1e-6, respectively. Making these one to two orders of magnitude smaller or larger may help, but note that your accuracy will suffer if the tolerances are too large.