Skip to content

Commit

Permalink
fix: Remove TerminalPDEProblem and replace with PIDEProblem
Browse files Browse the repository at this point in the history
add a try catch because `f` is interpreted differently for DeepBSDE and
DeepSplitting
  • Loading branch information
ashutosh-b-b committed Oct 8, 2023
1 parent 3da04e4 commit 211c120
Showing 1 changed file with 10 additions and 52 deletions.
62 changes: 10 additions & 52 deletions src/HighDimPDE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,16 @@ module HighDimPDE

isnothing(neumann_bc) ? nothing : @assert eltype(eltype(neumann_bc)) <: eltype(x)
@assert eltype(g(x)) == eltype(x) "Type of `g(x)` must match type of x"
@assert(eltype(f(x, x, g(x), g(x), x, x, p, tspan[1])) == eltype(x),
try
@assert(eltype(f(x, x, g(x), g(x), x, x, p, tspan[1])) == eltype(x),
"Type of non linear function `f(x)` must type of x")
catch e
if isa(e, MethodError)
@assert(eltype(f(x, eltype(x)(0.0), x, p, tspan[1])) == eltype(x), "Type of non linear function `f(x)` must type of x")
else
throw(e)
end
end

PIDEProblem{typeof(g(x)),
typeof(g),
Expand All @@ -92,56 +100,6 @@ module HighDimPDE
g(x), g, f, μ, σ, x, tspan, p, x0_sample, neumann_bc, kwargs)
end

"""
$(SIGNATURES)
Defines a semilinear Partial Differential Equation Problem with terminal conditions, of the form
```math
\\begin{aligned}
\\frac{du}{dt} &= \\tfrac{1}{2} \\text{Tr}(\\sigma \\sigma^T) \\Delta u(x, t) + \\mu \\nabla u(x, t) \\\\
&\\quad + f(x, u(x, t), ( \\nabla_x u )(x, t), p, t)
\\end{aligned}
```
with `` u(x,T) = g(x)``.
## Arguments
* `g` : The terminal condition, of the form `g(x, p, t)`.
* `f` : The nonlinear function, of the form `f(x, u(x, t), ∇u(x, t), p, t)`.
* `μ` : drift function, of the form `μ(x, p, t)`.
* `σ` : diffusion function `σ(x, p, t)`.
* `x`: point where `u(x,t)` is approximated. Is required even in the case where `x0_sample` is provided. Determines the dimensionality of the PDE.
* `tspan`: timespan of the problem.
* `p`: the parameter vector.
"""
struct TerminalPDEProblem{G,F,Mu,Sigma,X,T,P,A,UD,NBC,K} <: AbstractPDEProblem
g::G
f::F
μ::Mu
σ::Sigma
x::X
tspan::Tuple{T,T}
p::P
A::A
x0_sample::UD # the domain of u to be solved
neumann_bc::NBC # neumann boundary conditions, for now not used
kwargs::K
end

function TerminalPDEProblem(g,f,μ,σ,x,tspan,p=nothing;A=nothing,x0_sample=nothing,neumann_bc=nothing,kwargs...)
TerminalPDEProblem{typeof(g),
typeof(f),
typeof(μ),
typeof(σ),
typeof(x),
eltype(tspan),
typeof(p),
typeof(A),
typeof(x0_sample),
typeof(neumann_bc),
typeof(kwargs)}(
g,f,μ,σ,x,tspan,p,A,x0_sample,neumann_bc,kwargs)
end
struct PIDESolution{X0,Ts,L,Us,NNs,Ls}
x0::X0
ts::Ts
Expand Down Expand Up @@ -177,7 +135,7 @@ module HighDimPDE
include("DeepBSDE_Han.jl")
include("MLP.jl")

export PIDEProblem, TerminalPDEProblem, PIDESolution, DeepSplitting, DeepBSDE, MLP
export PIDEProblem, PIDEProblem, PIDESolution, DeepSplitting, DeepBSDE, MLP

export NormalSampling, UniformSampling, NoSampling, solve
end

0 comments on commit 211c120

Please sign in to comment.