Skip to content

Commit

Permalink
style changes, removed _log file, optimized lpmf
Browse files Browse the repository at this point in the history
  • Loading branch information
chvandorp committed Dec 16, 2023
1 parent 22d9eb2 commit beb70f3
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 38 deletions.
1 change: 0 additions & 1 deletion stan/math/prim/prob.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@
#include <stan/math/prim/prob/dirichlet_lpdf.hpp>
#include <stan/math/prim/prob/dirichlet_lpmf.hpp>
#include <stan/math/prim/prob/dirichlet_rng.hpp>
#include <stan/math/prim/prob/dirichlet_multinomial_log.hpp>
#include <stan/math/prim/prob/dirichlet_multinomial_lpmf.hpp>
#include <stan/math/prim/prob/dirichlet_multinomial_rng.hpp>
#include <stan/math/prim/prob/discrete_range_ccdf_log.hpp>
Expand Down
32 changes: 0 additions & 32 deletions stan/math/prim/prob/dirichlet_multinomial_log.hpp

This file was deleted.

7 changes: 4 additions & 3 deletions stan/math/prim/prob/dirichlet_multinomial_lpmf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,10 @@ return_type_t<T_prior_size> dirichlet_multinomial_lpmf(

auto ops_partials = make_partials_propagator(alpha_ref);
if (!is_constant_all<T_prior_size>::value) {
partials<0>(ops_partials) = digamma(alpha_val + ns_array)
- digamma(alpha_val) + digamma(a_sum)
- digamma(a_sum + n_sum);
partials<0>(ops_partials)
= (ns_array > 0)
.select(digamma(alpha_val + ns_array) - digamma(alpha_val), 0.0)
+ digamma(a_sum) - digamma(a_sum + n_sum);
}
return ops_partials.build(lp);
}
Expand Down
3 changes: 2 additions & 1 deletion stan/math/prim/prob/dirichlet_multinomial_rng.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,9 @@ inline std::vector<int> dirichlet_multinomial_rng(
check_nonnegative(function, "number of trials variables", N);

// special case N = 0 would lead to an exception thrown by multinomial_rng
if (N == 0)
if (N == 0) {
return std::vector<int>(alpha.size(), 0);
}

// sample a simplex theta from the Dirichlet distribution
auto theta = dirichlet_rng(alpha_ref, rng);
Expand Down
6 changes: 5 additions & 1 deletion test/unit/math/mix/prob/dirichlet_multinomial_test.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#include <test/unit/math/test_ad.hpp>

TEST(ProbDistributions, dirichlet_multinomial) {
TEST(ProbDistributionsDirichletMultinomial, dirichlet_multinomial) {
// bind integer vector arg because can't autodiff through
auto f = [](const std::vector<int>& y) {
return [=](const auto& alpha) {
Expand All @@ -11,6 +11,8 @@ TEST(ProbDistributions, dirichlet_multinomial) {

std::vector<int> y1 = {1, 2, 3, 4};
std::vector<int> y2 = {30, 75, 409, 34};
// test if zero counts are handled correctly
std::vector<int> y3 = {0, 5, 0, 10};

Eigen::VectorXd alpha1(4);
alpha1 << 1.0, 2.0, 3.0, 4.0;
Expand All @@ -25,9 +27,11 @@ TEST(ProbDistributions, dirichlet_multinomial) {
EXPECT_NO_THROW(f(y1)(alpha2));
EXPECT_NO_THROW(f(y2)(alpha1));
EXPECT_NO_THROW(f(y2)(alpha2));
EXPECT_NO_THROW(f(y3)(alpha1));

stan::test::expect_ad(f(y1), alpha1);
stan::test::expect_ad(f(y1), alpha2);
stan::test::expect_ad(f(y2), alpha1);
stan::test::expect_ad(f(y2), alpha2);
stan::test::expect_ad(f(y3), alpha1);
}

0 comments on commit beb70f3

Please sign in to comment.