diff --git a/src/DeepBSDE.jl b/src/DeepBSDE.jl index 56d17f9..c421b21 100644 --- a/src/DeepBSDE.jl +++ b/src/DeepBSDE.jl @@ -59,7 +59,7 @@ end DeepBSDE(u0, σᵀ∇u; opt = Flux.Optimise.Adam(0.1)) = DeepBSDE(u0, σᵀ∇u, opt) """ -$(SIGNATURES) +$(TYPEDSIGNATURES) Returns a `PIDESolution` object. @@ -73,6 +73,8 @@ Returns a `PIDESolution` object. [DifferentialEquations.jl doc](https://diffeq.sciml.ai/stable/solvers/sde_solve/). - `limits`: if `true`, upper and lower limits will be calculated, based on [Deep Primal-Dual algorithm for BSDEs](https://papers.ssrn.com/sol3/papers.cfm?abstract_id=3071506). +- `maxiters`: The number of training epochs. Defaults to `300` +- `trajectories`: The number of trajectories simulated for training. Defaults to `100` - Extra keyword arguments passed to `solve` will be further passed to the SDE solver. """ function DiffEqBase.solve(prob::ParabolicPDEProblem, diff --git a/src/DeepBSDE_Han.jl b/src/DeepBSDE_Han.jl index eb785b7..75744c8 100644 --- a/src/DeepBSDE_Han.jl +++ b/src/DeepBSDE_Han.jl @@ -1,4 +1,15 @@ # called whenever sdealg is not specified. +""" +$(TYPEDSIGNATURES) + +Returns a `PIDESolution` object. + +# Arguments: +- `maxiters`: The number of training epochs. Defaults to `300` +- `trajectories`: The number of trajectories simulated for training. Defaults to `100` + +To use [SDE Algorithms](https://diffeq.sciml.ai/stable/solvers/sde_solve/) use [`DeepBSDE`](@ref) +""" function DiffEqBase.solve(prob::ParabolicPDEProblem, alg::DeepBSDE; dt, diff --git a/src/DeepSplitting.jl b/src/DeepSplitting.jl index 9541c52..91ca483 100644 --- a/src/DeepSplitting.jl +++ b/src/DeepSplitting.jl @@ -51,7 +51,7 @@ function DeepSplitting(nn; end """ -$(SIGNATURES) +$(TYPEDSIGNATURES) Returns a `PIDESolution` object. diff --git a/src/MLP.jl b/src/MLP.jl index a6e7e62..c640649 100644 --- a/src/MLP.jl +++ b/src/MLP.jl @@ -21,7 +21,7 @@ end MLP(; M = 4, L = 4, K = 10, mc_sample = NoSampling()) = MLP(M, L, K, mc_sample) """ -$(SIGNATURES) +$(TYPEDSIGNATURES) Returns a `PIDESolution` object. diff --git a/src/NNStopping.jl b/src/NNStopping.jl index 39d1793..0f49f25 100644 --- a/src/NNStopping.jl +++ b/src/NNStopping.jl @@ -54,6 +54,22 @@ function (model::NNStoppingModelArray)(X, G) broadcast((x, m) -> m(x), eachslice(XG, dims = 2)[2:end], model.ms) end +""" +$(TYPEDSIGNATURES) + +Returns a NamedTuple with `payoff` and `stopping_time` + +Arguments: +- `sdealg`: a SDE solver from [DifferentialEquations.jl](https://diffeq.sciml.ai/stable/solvers/sde_solve/). + If not provided, the plain vanilla [DeepBSDE](https://arxiv.org/abs/1707.02568) method will be applied. + If provided, the SDE associated with the PDE problem will be solved relying on + methods from DifferentialEquations.jl, using [Ensemble solves](https://diffeq.sciml.ai/stable/features/ensemble/) + via `sdealg`. Check the available `sdealg` on the + [DifferentialEquations.jl doc](https://diffeq.sciml.ai/stable/solvers/sde_solve/). +- `maxiters`: The number of training epochs. Defaults to `300` +- `trajectories`: The number of trajectories simulated for training. Defaults to `100` +- Extra keyword arguments passed to `solve` will be further passed to the SDE solver. +""" function DiffEqBase.solve(prob::ParabolicPDEProblem, pdealg::NNStopping, sdealg;