Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify calculations of tables for DirichletProcess and PitmanYorProcess #970

Merged
merged 2 commits into from
Nov 18, 2019

Conversation

devmotion
Copy link
Member

This PR simplifies the calculations of the tables for DirichletProcess and PitmanYorProcess by avoiding redundant summations and loops over the provided cluster counts. Additionally, the explicit definition of the output arrays (and in particular its element types) ensures that the function is type-stable.

Moreover, in the sampling from the Chinese restaurant process the use of the in-place softmax function reduces allocations.

@xukai92 xukai92 requested a review from trappmartin November 17, 2019 20:43
@torfjelde
Copy link
Member

I just did a quick scan and things look nice, but have you checked differentiability? To me the "old" code looks like it's made an explicit effort to make things work with AD, and for example the use of inplace softmax! won't.

Maybe @trappmartin can chime in on whether or not we actually need to support AD in this case:)

@devmotion
Copy link
Member Author

Ah, good point, I haven't checked that (I just ran the regular tests of the processes, IIRC). If a specialized implementation for AD is required, then one could dispatch depending on the parameter type of the processes. I guess it would still make sense to optimize the existing code even in that case, e.g., by avoiding redundant summations and using softmax (I guess that should be fine and probably be slightly more performant and numerically stable than the current version).

Copy link
Member

@trappmartin trappmartin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the effort, look good. Please seem my comments.

Regarding the softmax!, I don't see any problem here as it occurs only in the rand function. Did the test of the Dirichlet process run through? If so, it might be necessary to improve the tests as you seem to have a bug in your implementation.


# pre-calculations
dα = d.α
z = log(sum_m) - 1 + dα
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks wrong. It should probably be z = log(sum_m - 1 + dα)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops, yes, I'll fix it.

mi = m[i]

if iszero(mi)
if contains_zero
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add a comment here, so that people know why table[i] will be set to -Inf.

# pre-calculations
dα = d.α
z = log(sum_m) - 1 + dα
table_zero = log(dα) - z
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this pre-allocation necessary? We anyhow set only one table entry to this value.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not required, I just thought it might be easier to see that the same value is added to the table if m contains a 0 (at the index of the first 0) and if it doesn't (then it's pushed to the end). Of course, I can just duplicate the line for both cases, then no precalculation is required.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't have a strong opinion about this but it feels somewhat unnecessary.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't have strong feelings about it either, so I just remove it.

dθ = d.θ
dd = d.d
z = log(sum_m + dθ)
table_zero = log(dθ + dd * d.t) - z
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, do we need this pre-allocation?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The intention was the same as above.

@devmotion
Copy link
Member Author

I rechecked the results of the tests in stdlib/RandomMeasures.jl (which are commented out by default), on master, the initial commit in this PR which contains two bugs, and the updated version of this PR in which these bugs are fixed. In all cases all tests in the distributions testset pass, however, all other testsets chinese restaurant processes, stick breaking, and size-based resampling error due to a KeyError when calling Libtask.consume(::Trace). I assume that due to these test errors (which exist on master as well and which I haven't paid too much attention to yet) I did miss the bugs in the initial commit.

@trappmartin
Copy link
Member

Thanks! I’ll take a look at the tests.
Thanks for the PR.

@trappmartin trappmartin merged commit 89e859c into TuringLang:master Nov 18, 2019
@devmotion devmotion deleted the randommeasures branch November 18, 2019 15:16
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants