Fix rand and randn type piracy #387

Merged
merged 18 commits into from
Feb 22, 2025

devmotion
Member

devmotion commented Feb 18, 2025

I just came across a seemingly severe case of type piracy in AdvancedHMC. As a first step, I just removed it to see which parts are currently making use of it.

Edit: I removed the definition of Base.rand (replaced with branches in the two functions where it was used) and added an internal _randn to fix the type piracy of Base.randn (it falls back to Base.randn for single RNGs and keeps the current behaviour for vectors of RNGs). Additionally, I renamed the existing internal _rand function to rand_momentum to make it clearer that it's just about sampling a momentum and not similar to _randn.
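To make the `_randn` idea concrete, here is a minimal sketch of an internal helper in that spirit. The signature (element type, `dim`, `n_chains`) and the one-RNG-per-column layout for vectors of RNGs are assumptions for illustration, not AdvancedHMC's actual code. Because `_randn` is a function the package owns, adding a method for `AbstractVector{<:AbstractRNG}` no longer changes what `Base.randn` means for other packages.

```julia
using Random

# Hypothetical sketch: names and signatures are illustrative, not AdvancedHMC's.

# Single RNG: just defer to Base.randn; no Base method is extended.
_randn(rng::AbstractRNG, ::Type{T}, dim::Int, n_chains::Int) where {T} =
    randn(rng, T, dim, n_chains)

# Vector of RNGs (assumed layout): one RNG per chain, each filling its own column.
function _randn(
    rngs::AbstractVector{<:AbstractRNG}, ::Type{T}, dim::Int, n_chains::Int
) where {T}
    length(rngs) == n_chains ||
        throw(ArgumentError("expected one RNG per chain"))
    out = Matrix{T}(undef, dim, n_chains)
    for (j, rng) in enumerate(rngs)
        # Draw column j in place from the j-th RNG.
        randn!(rng, view(out, :, j))
    end
    return out
end
```

For example, `_randn([MersenneTwister(i) for i in 1:3], Float64, 2, 3)` returns a 2×3 matrix whose columns come from three separately seeded RNGs, while `_randn(MersenneTwister(0), Float64, 2, 3)` is simply `randn` with a single RNG.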

devmotion and others added 6 commits February 18, 2025 21:34
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
devmotion mentioned this pull request Feb 19, 2025
devmotion marked this pull request as ready for review February 20, 2025 08:40
yebai requested a review from penelopeysm February 20, 2025 09:28
@penelopeysm
Member

@devmotion Could I get you to fix the conflicts? Then I can take a look again.

@devmotion
Member Author

Done, I fixed the merge conflict.

Comment on lines 144 to 148
  Base.rand(rng::AbstractRNG, metric::AbstractMetric, kinetic, θ) =
-     _rand(rng, metric, kinetic) # this disambiguity is required by Random.rand
+     rand_momentum(rng, metric, kinetic) # this disambiguity is required by Random.rand
  Base.rand(rng::AbstractVector{<:AbstractRNG}, metric::AbstractMetric, kinetic, θ) =
-     _rand(rng, metric, kinetic)
+     rand_momentum(rng, metric, kinetic)
  Base.rand(metric::AbstractMetric, kinetic, θ) = rand(Random.default_rng(), metric, kinetic)
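For contrast with the removed piracy, the `Base.rand` methods kept in this diff dispatch on `AbstractMetric`, a type the package owns, so they cannot change the behaviour of calls that involve only Base or Random types. The sketch below illustrates that pattern with hypothetical stand-in types (`UnitMetric`, `GaussianKinetic`) that are not AdvancedHMC's definitions; only the names `AbstractMetric`, `rand_momentum` and the shape of the `Base.rand` signature come from the diff.

```julia
using Random

# Hypothetical stand-ins for package-owned types (not AdvancedHMC's definitions).
abstract type AbstractMetric end
struct UnitMetric <: AbstractMetric
    dim::Int
end
struct GaussianKinetic end

# Internal helper: draw a momentum for the given metric and kinetic energy.
rand_momentum(rng::AbstractRNG, metric::UnitMetric, ::GaussianKinetic) =
    randn(rng, metric.dim)

# Extending Base.rand here is not piracy: the second argument is a type this
# code owns, so the method only applies to calls involving AbstractMetric.
Base.rand(rng::AbstractRNG, metric::AbstractMetric, kinetic, θ) =
    rand_momentum(rng, metric, kinetic)

# Convenience method that uses the default RNG.
Base.rand(metric::AbstractMetric, kinetic, θ) =
    rand(Random.default_rng(), metric, kinetic, θ)
```

With these definitions, `rand(MersenneTwister(1), UnitMetric(3), GaussianKinetic(), zeros(3))` returns a length-3 momentum vector, and the public `Base.rand` entry point stays a thin wrapper around the internal `rand_momentum`.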
Member


> Additionally, I renamed the existing internal _rand function to rand_momentum to make it clearer that it's just about sampling a momentum

I'm always hugely in favour of disambiguating names, so I'm very glad to see this.

There seem to be a few of these definitions:

Base.rand(args...) = rand_momentum(fewer_args...)

Would it make sense to rename those Base.rand to rand_momentum as well? Happy to leave it to another time too, just curious if there's a reason you didn't do that.

Member Author


I was about to remove them but then decided to keep them in this PR, because they could (should?) be considered part of the public API, whereas _rand etc. are surely internal.

Member


Hmm, at first I thought that made sense, but looking more closely I realised that the research subdirectory is just a bunch of scripts that use AdvancedHMC rather than part of the module itself. So I guess the name doesn't matter that much. Feel free to hit merge if you're happy.

devmotion merged commit b61cbb5 into master Feb 22, 2025
17 checks passed
devmotion deleted the dw/rand_piracy branch February 22, 2025 15:54