Simplify calculations of tables for DirichletProcess and PitmanYorProcess #970
I just did a quick scan and things look nice, but have you checked differentiability? To me the "old" code looks like it made an explicit effort to work with AD, and for example the use of in-place […] Maybe @trappmartin can chime in on whether or not we actually need to support AD in this case :)
Ah, good point, I haven't checked that (I just ran the regular tests of the processes, IIRC). If a specialized implementation for AD is required, one could dispatch on the parameter type of the processes. I guess it would still make sense to optimize the existing code even in that case, e.g., by avoiding redundant summations and using […]
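To make the dispatch idea concrete, here is a minimal sketch. `ToyProcess` and `implementation_kind` are hypothetical names invented for this illustration and are not part of Turing; the sketch only assumes that AD number types are subtypes of `Real` but not of `AbstractFloat`, so they would fall through to the generic method.

```julia
# Hypothetical stand-in for a process type with a parametric concentration parameter.
struct ToyProcess{T<:Real}
    α::T
end

# Plain floating-point parameters: an optimized (e.g. mutating) code path could be used here.
implementation_kind(::ToyProcess{<:AbstractFloat}) = :optimized

# Any other Real parameter type (e.g. an AD number type): fall back to a generic code path.
implementation_kind(::ToyProcess) = :generic

kind_float = implementation_kind(ToyProcess(1.0))     # parameter is a Float64
kind_other = implementation_kind(ToyProcess(1 // 2))  # Rational stands in for a non-float Real
```

Whether two code paths are actually needed depends on the answer to the AD question above; if AD support is not required, a single optimized method would be enough.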
Thanks for the effort, looks good. Please see my comments.
Regarding the `softmax!`, I don't see any problem here as it occurs only in the `rand` function. Did the test of the Dirichlet process run through? If so, it might be necessary to improve the tests as you seem to have a bug in your implementation.
src/stdlib/RandomMeasures.jl
# pre-calculations
dα = d.α
z = log(sum_m) - 1 + dα
This looks wrong. It should probably be z = log(sum_m - 1 + dα)
Oops, yes, I'll fix it.
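For readers following along, a small sketch of why the two expressions differ. The values of `sum_m` and `α` are made up for illustration only; the two formulas are the ones quoted in the comments above.

```julia
# Hypothetical values, chosen only for illustration.
sum_m = 10   # sum of all cluster counts
α = 0.5      # concentration parameter (`d.α` in the snippet under review)

# As written in the diff: parsed as (log(sum_m) - 1) + α.
z_buggy = log(sum_m) - 1 + α

# As suggested in the review: the logarithm of the whole normalizer.
z_fixed = log(sum_m - 1 + α)

println("buggy: ", z_buggy, "  fixed: ", z_fixed)
```

The two values only coincide by accident, so any test that compares the table against independently computed log-probabilities should catch the difference.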
mi = m[i]
if iszero(mi)
    if contains_zero
Maybe add a comment here, so that people know why `table[i]` will be set to `-Inf`.
src/stdlib/RandomMeasures.jl
# pre-calculations
dα = d.α
z = log(sum_m) - 1 + dα
table_zero = log(dα) - z
Is this pre-allocation necessary? We anyhow set only one table entry to this value.
It's not required, I just thought it might be easier to see that the same value is added to the table if `m` contains a 0 (at the index of the first 0) and if it doesn't (then it's pushed to the end). Of course, I can just duplicate the line for both cases, then no precalculation is required.
I don't have a strong opinion about this but it feels somewhat unnecessary.
I don't have strong feelings about it either, so I'll just remove it.
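The behaviour discussed in this thread can be sketched roughly as follows. This is not the code from the PR: `crp_logtable` is a hypothetical helper written for illustration, it takes the concentration parameter directly instead of a `DirichletProcess`, it uses the corrected normalizer `log(sum_m - 1 + α)`, and it assumes at least one non-empty cluster.

```julia
# Hypothetical helper, not the PR's implementation.
# Assumes sum(m) >= 1 so that the argument of the logarithm is positive.
function crp_logtable(α::Real, m::AbstractVector{<:Integer})
    z = log(sum(m) - 1 + α)               # the counts are summed only once
    T = typeof(z)
    table = Vector{T}(undef, length(m))   # explicit element type for the output
    contains_zero = false
    for i in eachindex(m)
        mi = m[i]
        if iszero(mi)
            if contains_zero
                # A "new table" slot already exists at the first zero, so further
                # empty clusters get zero probability, i.e. -Inf in log space.
                table[i] = T(-Inf)
            else
                # First empty cluster: it plays the role of the new table.
                table[i] = log(α) - z
                contains_zero = true
            end
        else
            table[i] = log(mi) - z
        end
    end
    # No empty cluster in `m`: append the new-table entry at the end.
    contains_zero || push!(table, log(α) - z)
    return table
end

with_zeros = crp_logtable(0.5, [3, 0, 2, 0])   # same length as the input
without_zeros = crp_logtable(0.5, [3, 2])      # one entry longer than the input
```

Writing `log(α) - z` in both places, as agreed above, avoids the extra `table_zero` variable at the cost of one duplicated expression.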
src/stdlib/RandomMeasures.jl
dθ = d.θ
dd = d.d
z = log(sum_m + dθ)
table_zero = log(dθ + dd * d.t) - z
Again, do we need this pre-allocation?
The intention was the same as above.
I rechecked the results of the tests in […]
Thanks! I’ll take a look at the tests.
This PR simplifies the calculation of the tables for DirichletProcess and PitmanYorProcess by avoiding redundant summations and loops over the provided cluster counts. Additionally, explicitly constructing the output arrays (and in particular fixing their element types) ensures that the functions are type-stable.
Moreover, in the sampling from the Chinese restaurant process the use of the in-place softmax function reduces allocations.
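As a rough illustration of the allocation point, the sketch below overwrites a vector of log-probabilities with the corresponding probabilities. `softmax_inplace!` is a hand-written, hypothetical stand-in and not the `softmax!` used in the PR; it assumes at least one finite entry in the input.

```julia
# Hypothetical stand-in for an in-place softmax: overwrites `x` with
# exp.(x) ./ sum(exp.(x)) without allocating a second vector.
# Assumes at least one entry of `x` is finite.
function softmax_inplace!(x::AbstractVector{<:AbstractFloat})
    u = maximum(x)               # subtract the maximum for numerical stability
    s = zero(eltype(x))
    for i in eachindex(x)
        x[i] = exp(x[i] - u)     # entries equal to -Inf become exactly 0
        s += x[i]
    end
    x ./= s                      # normalize in place
    return x
end

logtable = [log(3.0), log(0.5), -Inf, log(2.0)]
p = softmax_inplace!(logtable)   # `p` is the same array as `logtable`
```

Because the log-probability table is only needed to draw a single sample in `rand`, reusing its memory for the probabilities is harmless there.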