Skip to content

Commit

Permalink
remove recursion altogether
Browse files Browse the repository at this point in the history
  • Loading branch information
atmyers committed Feb 6, 2025
1 parent 6d7b889 commit 19e0de4
Showing 1 changed file with 39 additions and 35 deletions.
74 changes: 39 additions & 35 deletions Src/Base/AMReX_Random.H
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,39 @@ namespace amrex
#endif
}

namespace {

AMREX_GPU_DEVICE AMREX_FORCE_INLINE
Real RandomGamma_alpha_ge_1 (Real alpha, Real beta, RandomEngine const& random_engine)
{
AMREX_ASSERT(alpha >= 1);
AMREX_ASSERT(beta > 0);

Real x, v, u;
Real d = alpha - 1.0_rt / 3.0_rt;
Real c = (1.0_rt / 3.0_rt) / std::sqrt(d);

while (1) {
do {
x = amrex::RandomNormal(0.0_rt, 1.0_rt, random_engine);
v = 1.0_rt + c * x;
} while (v <= 0.0_rt);

v = v * v * v;
u = amrex::Random(random_engine);

if (u < 1.0_rt - 0.0331_rt * x * x * x * x) {
break;
}

if (std::log(u) < 0.5_rt * x * x + d * (1.0_rt - v + std::log(v))) {
break;
}
}
return beta * d * v;
}
}

/**
* \brief Generate a psuedo-random floating point number from the Gamma distribution
*
Expand All @@ -132,7 +165,6 @@ namespace amrex
*/
Real RandomGamma (Real alpha, Real beta);

template <int depth = 0>
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
Real RandomGamma (Real alpha, Real beta, RandomEngine const& random_engine)
{
Expand All @@ -142,45 +174,17 @@ namespace amrex
AMREX_IF_ON_DEVICE((
if (alpha < 1)
{
if constexpr (depth == 0)
{
// note that alpha is assumed to be > 0, so this function will recurse at most once.
Real u = amrex::Random(random_engine);
return RandomGamma<1>(1.0_rt + alpha, beta, random_engine) * std::pow(u, 1.0_rt / alpha);
}
}

{
Real x, v, u;
Real d = alpha - 1.0_rt / 3.0_rt;
Real c = (1.0_rt / 3.0_rt) / std::sqrt(d);

while (1) {
do {
x = amrex::RandomNormal(0.0_rt, 1.0_rt, random_engine);
v = 1.0_rt + c * x;
} while (v <= 0.0_rt);

v = v * v * v;
u = amrex::Random(random_engine);

if (u < 1.0_rt - 0.0331_rt * x * x * x * x) {
break;
}

if (std::log(u) < 0.5_rt * x * x + d * (1.0_rt - v + std::log(v))) {
break;
}
}
return beta * d * v;
Real u = amrex::Random(random_engine);
return RandomGamma_alpha_ge_1(1.0_rt + alpha, beta, random_engine) * std::pow(u, 1.0_rt / alpha);
} else {
RandomGamma_alpha_ge_1(alpha, beta, random_engine);
}
))

AMREX_IF_ON_HOST((
amrex::ignore_unused(random_engine);
return RandomGamma(alpha, beta);
amrex::ignore_unused(random_engine);
return RandomGamma(alpha, beta);
))

}

/**
Expand Down

0 comments on commit 19e0de4

Please sign in to comment.