Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Cosmology.from_sigma8() for a_stop!=1 #31

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pmwd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from pmwd.gravity import laplace, neg_grad, gravity
from pmwd.modes import white_noise, linear_modes
from pmwd.lpt import lpt
from pmwd.nbody import nbody
from pmwd.nbody import nbody, nbody_scan
try:
from pmwd._version import __version__
except ModuleNotFoundError:
Expand Down
1 change: 1 addition & 0 deletions pmwd/boltzmann.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ def varlin_integ(cosmo, conf):
``(len(conf.varlin_R),)`` and ``conf.cosmo_dtype``.

"""
# no scale factor specified here!!!
Plin = linear_power(conf.var_tophat.x, None, cosmo, conf)

_, varlin = conf.var_tophat(Plin, extrap=True)
Expand Down
8 changes: 7 additions & 1 deletion pmwd/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,11 +313,17 @@ def a_nbody(self):
"""N-body time integration scale factor steps, including ``a_start``, of ``cosmo_dtype``."""
return jnp.linspace(self.a_start, self.a_stop, num=1+self.a_nbody_num,
dtype=self.cosmo_dtype)

@property
def growth_a_num(self):
"""Number of growth factor points."""
return math.ceil(1. / self.a_lpt_maxstep)

@property
def growth_a(self):
"""Growth function scale factors, for both LPT and N-body, of ``cosmo_dtype``."""
return jnp.concatenate((self.a_lpt, self.a_nbody[1:]))
return jnp.linspace(0., 1., num=self.growth_a_num,
dtype=self.cosmo_dtype)

@property
def varlin_R(self):
Expand Down
2 changes: 1 addition & 1 deletion pmwd/cosmology.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def sigma8(self):
"""Linear matter rms overdensity within a tophat sphere of 8 Mpc/h radius at a=1."""
from pmwd.boltzmann import varlin
R = 8 * self.conf.Mpc_SI / self.conf.L
return jnp.sqrt(varlin(R, 1, self, self.conf))
return jnp.sqrt(varlin(R, 1., self, self.conf))

@property
def ptcl_mass(self):
Expand Down
52 changes: 52 additions & 0 deletions pmwd/nbody.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from jax import value_and_grad, jit, vjp, custom_vjp
import jax.numpy as jnp
from jax.lax import scan
from jax.tree_util import tree_map

from pmwd.boltzmann import growth
Expand Down Expand Up @@ -216,12 +217,29 @@ def nbody_step(a_prev, a_next, ptcl, obsvbl, cosmo, conf):
def nbody(ptcl, obsvbl, cosmo, conf, reverse=False):
"""N-body time integration."""
a_nbody = conf.a_nbody[::-1] if reverse else conf.a_nbody
a_nbody_arr = jnp.array([a_nbody[:-1], a_nbody[1:]]).T

ptcl, obsvbl = nbody_init(a_nbody[0], ptcl, obsvbl, cosmo, conf)

for a_prev, a_next in zip(a_nbody[:-1], a_nbody[1:]):
ptcl, obsvbl = nbody_step(a_prev, a_next, ptcl, obsvbl, cosmo, conf)
return ptcl, obsvbl

@partial(custom_vjp, nondiff_argnums=(4,))
def nbody_scan(ptcl, obsvbl, cosmo, conf, reverse=False):
"""N-body time integration. Use jax.lax.scan to speed up the compilation."""
a_nbody = conf.a_nbody[::-1] if reverse else conf.a_nbody
a_nbody_arr = jnp.array([a_nbody[:-1], a_nbody[1:]]).T

ptcl, obsvbl = nbody_init(a_nbody[0], ptcl, obsvbl, cosmo, conf)

def _nbody_step(carry, x):
ptcl, obsvbl = carry
a_prev, a_next = x
ptcl, obsvbl = nbody_step(a_prev, a_next, ptcl, obsvbl, cosmo, conf)
return (ptcl, obsvbl), None
(ptcl, obsvbl), _ = scan(_nbody_step, (ptcl, obsvbl), a_nbody_arr)
return ptcl, obsvbl

@jit
def nbody_adj_init(a, ptcl, ptcl_cot, obsvbl_cot, cosmo, conf):
Expand Down Expand Up @@ -259,11 +277,35 @@ def nbody_adj(ptcl, ptcl_cot, obsvbl_cot, cosmo, conf, reverse=False):
a_prev, a_next, ptcl, ptcl_cot, obsvbl_cot, cosmo, cosmo_cot, cosmo_cot_force, conf)
return ptcl, ptcl_cot, cosmo_cot

def nbody_adj_scan(ptcl, ptcl_cot, obsvbl_cot, cosmo, conf, reverse=False):
"""N-body time integration with adjoint equation. Use jax.lax.scan to speed up the compilation."""
a_nbody = conf.a_nbody[::-1] if reverse else conf.a_nbody
a_nbody_arr = jnp.array([a_nbody[:0:-1], a_nbody[-2::-1]]).T


ptcl, ptcl_cot, cosmo_cot, cosmo_cot_force = nbody_adj_init(
a_nbody[-1], ptcl, ptcl_cot, obsvbl_cot, cosmo, conf)

def _nbody_adj_step(carry, x):
ptcl, ptcl_cot, cosmo_cot, cosmo_cot_force = carry
a_prev, a_next = x
ptcl, ptcl_cot, cosmo_cot, cosmo_cot_force = nbody_adj_step(
a_prev, a_next, ptcl, ptcl_cot, obsvbl_cot, cosmo, cosmo_cot, cosmo_cot_force, conf)
return (ptcl, ptcl_cot, cosmo_cot, cosmo_cot_force), None

(ptcl, ptcl_cot, cosmo_cot, cosmo_cot_force), _ = scan(_nbody_adj_step, (ptcl, ptcl_cot, cosmo_cot, cosmo_cot_force), a_nbody_arr)

return ptcl, ptcl_cot, cosmo_cot


def nbody_fwd(ptcl, obsvbl, cosmo, conf, reverse):
ptcl, obsvbl = nbody(ptcl, obsvbl, cosmo, conf, reverse)
return (ptcl, obsvbl), (ptcl, cosmo, conf)

def nbody_fwd_scan(ptcl, obsvbl, cosmo, conf, reverse):
ptcl, obsvbl = nbody_scan(ptcl, obsvbl, cosmo, conf, reverse)
return (ptcl, obsvbl), (ptcl, cosmo, conf)

def nbody_bwd(reverse, res, cotangents):
ptcl, cosmo, conf = res
ptcl_cot, obsvbl_cot = cotangents
Expand All @@ -273,4 +315,14 @@ def nbody_bwd(reverse, res, cotangents):

return ptcl_cot, obsvbl_cot, cosmo_cot, None

def nbody_bwd_scan(reverse, res, cotangents):
ptcl, cosmo, conf = res
ptcl_cot, obsvbl_cot = cotangents

ptcl, ptcl_cot, cosmo_cot = nbody_adj_scan(ptcl, ptcl_cot, obsvbl_cot, cosmo, conf,
reverse=reverse)

return ptcl_cot, obsvbl_cot, cosmo_cot, None

nbody.defvjp(nbody_fwd, nbody_bwd)
nbody_scan.defvjp(nbody_fwd_scan, nbody_bwd_scan)