-
Notifications
You must be signed in to change notification settings - Fork 11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Move Libtask to an extension #75
Conversation
We probably need to think about the handling of the rng in the non-libtask case |
Pull Request Test Coverage Report for Build 6448601669
💛 - Coveralls |
Codecov ReportAttention:
Additional details and impacted files@@ Coverage Diff @@
## master #75 +/- ##
==========================================
- Coverage 96.32% 95.60% -0.72%
==========================================
Files 7 8 +1
Lines 381 410 +29
==========================================
+ Hits 367 392 +25
- Misses 14 18 +4
☔ View full report in Codecov by Sentry. |
|
||
Base.copy(model::LibtaskModel) = LibtaskModel(model.f, copy(model.ctask)) | ||
|
||
const LibtaskTrace{R} = AdvancedPS.Trace{<:LibtaskModel,R} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we can access this LibtaskModel
externally. Hence, users can not use this libtask extension even when Libtask
is loaded. To fix this, we can consider overloading a function defined in AdvancedPS
.
const LibtaskTrace{R} = AdvancedPS.Trace{<:LibtaskModel,R} | |
const LibtaskTrace{R} = AdvancedPS.Trace{<:LibtaskModel,R} | |
AdvancedPS.Trace(Libtask, f, args; __rng__ = Random.default_rng()) = AdvancedPS.Trace(LibtaskModel(f, args...), rng=__rng__) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's not possible (for general Julia versions - maybe at all?) to dispatch on modules. One will be able to use them e.g. as type parameters only in 1.10: JuliaLang/julia#47749
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's possible to access types by retrieving the extension module with get_extension
. But that does not lead to a good API IMO and also adding accessible types is not the intended use case of extensions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, @devmotion -- that's very helpful to know. However, in this case, we can dispatch any type. For example:
const LibtaskTrace{R} = AdvancedPS.Trace{<:LibtaskModel,R} | |
const LibtaskTrace{R} = AdvancedPS.Trace{<:LibtaskModel,R} | |
AdvancedPS.Trace(Val{:libtask}, f, args; __rng__ = Random.default_rng()) = AdvancedPS.Trace(LibtaskModel(f, args...), rng=__rng__) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wanted to hide all the internals away from the user, not sure they should know about what Trace
is. The main package implements sample(::AbstractStateSpace)
and the extension overloads it with sample(::AbstractModel)
.
This works only if you load Libtask
(and enable the extension):
struct Model <: AbstractMCMC.AbstractModel end
function (model::Model)(rng::AbstractRNG)
x = rand(...)
end
sample(rng, Model(), PG(100), 1_000)
test/container.jl
Outdated
# # Test task copy version of trace | ||
# trng = AdvancedPS.TracedRNG() | ||
# tmodel = AdvancedPS.GenericModel(f2, trng) | ||
# tr = AdvancedPS.Trace(tmodel, trng) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
With the additional constructor proposed in https://github.com/TuringLang/AdvancedPS.jl/pull/75/files#r1284737656, we should be able to uncomment these tests.
Excellent work -- I left some suggestions above. |
631900a
to
006c4fd
Compare
Just realized while fixing some of tests that the AdvancedPS.jl/src/container.jl Line 227 in 72a1e55
update_ref! should probably dispatch on the type of the sampler somehow.
|
Something like this would actually would use resample the ancestor particle for both chains_pg = sample(model, pg, Nₛ);
chains_pgas = sample(model, pgas, Nₛ); |
Not sure I follow here, can you clarity? |
006c4fd
to
806c10a
Compare
Still a draft but we can track the progress on moving Libtask to its own extension