-
-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #83 from ashutosh-b-b/bb/nn_kolmogorov
Add NNKolmogorov and NNParamKolmogorov
- Loading branch information
Showing
13 changed files
with
615 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# [The `NNKolmogorov` algorithm](@id nn_komogorov) | ||
|
||
### Problems Supported: | ||
1. [`ParabolicPDEProblem`](@ref) | ||
|
||
```@autodocs | ||
Modules = [HighDimPDE] | ||
Pages = ["NNKolmogorov.jl"] | ||
``` | ||
|
||
`NNKolmogorov` obtains a | ||
- terminal solution for Forward Kolmogorov Equations of the form: | ||
```math | ||
\partial_t u(t,x) = \mu(t, x) \nabla_x u(t,x) + \frac{1}{2} \sigma^2(t, x) \Delta_x u(t,x) | ||
``` | ||
with initial condition given by `g(x)` | ||
- or an initial condition for Backward Kolmogorov Equations of the form: | ||
```math | ||
\partial_t u(t,x) = - \mu(t, x) \nabla_x u(t,x) - \frac{1}{2} \sigma^2(t, x) \Delta_x u(t,x) | ||
``` | ||
with terminal condition given by `g(x)` | ||
|
||
We can use the Feynman-Kac formula : | ||
```math | ||
S_t^x = \int_{0}^{t}\mu(S_s^x)ds + \int_{0}^{t}\sigma(S_s^x)dB_s | ||
``` | ||
And the solution is given by: | ||
```math | ||
f(T, x) = \mathbb{E}[g(S_T^x)] | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# [The `NNParamKolmogorov` algorithm](@id nn_komogorov) | ||
|
||
### Problems Supported: | ||
1. [`ParabolicPDEProblem`](@ref) | ||
|
||
```@autodocs | ||
Modules = [HighDimPDE] | ||
Pages = ["NNParamKolmogorov.jl"] | ||
``` | ||
|
||
`NNParamKolmogorov` obtains a | ||
- terminal solution for parametric families of Forward Kolmogorov Equations of the form: | ||
```math | ||
\partial_t u(t,x) = \mu(t, x, γ_mu) \nabla_x u(t,x) + \frac{1}{2} \sigma^2(t, x, γ_sigma) \Delta_x u(t,x) | ||
``` | ||
with initial condition given by `g(x, γ_phi)` | ||
- or an initial condition for parametric families of Backward Kolmogorov Equations of the form: | ||
```math | ||
\partial_t u(t,x) = - \mu(t, x) \nabla_x u(t,x) - \frac{1}{2} \sigma^2(t, x) \Delta_x u(t,x) | ||
``` | ||
with terminal condition given by `g(x, γ_phi)` | ||
|
||
We can use the Feynman-Kac formula : | ||
```math | ||
S_t^x = \int_{0}^{t}\mu(S_s^x)ds + \int_{0}^{t}\sigma(S_s^x)dB_s | ||
``` | ||
And the solution is given by: | ||
```math | ||
f(T, x) = \mathbb{E}[g(S_T^x, γ_phi)] | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
# `NNKolmogorov` | ||
|
||
## Solving high dimensional Rainbow European Options for a range of initial stock prices: | ||
|
||
```julia | ||
d = 10 # dims | ||
T = 1/12 | ||
sigma = 0.01 .+ 0.03.*Matrix(Diagonal(ones(d))) # volatility | ||
mu = 0.06 # interest rate | ||
K = 100.0 # strike price | ||
function μ_func(du, u, p, t) | ||
du .= mu*u | ||
end | ||
|
||
function σ_func(du, u, p, t) | ||
du .= sigma * u | ||
end | ||
|
||
tspan = (0.0, T) | ||
# The range for initial stock price | ||
xspan = [(98.00, 102.00) for i in 1:d] | ||
|
||
g(x) = max(maximum(x) -K, 0) | ||
|
||
sdealg = EM() | ||
# provide `x0` as nothing to the problem since we are provinding a range for `x0`. | ||
prob = ParabolicPDEProblem(μ_func, σ_func, nothing, tspan, g = g, xspan = xspan) | ||
opt = Flux.Optimisers.Adam(0.01) | ||
alg = NNKolmogorov(m, opt) | ||
m = Chain(Dense(d, 16, elu), Dense(16, 32, elu), Dense(32, 16, elu), Dense(16, 1)) | ||
sol = solve(prob, alg, sdealg, verbose = true, dt = 0.01, | ||
dx = 0.0001, trajectories = 1000, abstol = 1e-6, maxiters = 300) | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
# `NNParamKolmogorov` | ||
|
||
## Solving Parametric Family of High Dimensional Heat Equation. | ||
|
||
In this example we will solve the high dimensional heat equation over a range of initial values, and also over a range of thermal diffusivity. | ||
```julia | ||
d = 10 | ||
# models input is `d` for initial values, `d` for thermal diffusivity, and last dimension is for stopping time. | ||
m = Chain(Dense(d + 1 + 1, 32, relu), Dense(32, 16, relu), Dense(16, 8, relu), Dense(8, 1)) | ||
ensemblealg = EnsembleThreads() | ||
γ_mu_prototype = nothing | ||
γ_sigma_prototype = zeros(1, 1) | ||
γ_phi_prototype = nothing | ||
|
||
sdealg = EM() | ||
tspan = (0.00, 1.00) | ||
trajectories = 100000 | ||
function phi(x, y_phi) | ||
sum(x .^ 2) | ||
end | ||
function sigma_(dx, x, γ_sigma, t) | ||
dx .= γ_sigma[:, :, 1] | ||
end | ||
mu_(dx, x, γ_mu, t) = dx .= 0.00 | ||
|
||
xspan = [(0.00, 3.00) for i in 1:d] | ||
|
||
p_domain = (p_sigma = (0.00, 2.00), p_mu = nothing, p_phi = nothing) | ||
p_prototype = (p_sigma = γ_sigma_prototype, p_mu = γ_mu_prototype, p_phi = γ_phi_prototype) | ||
dps = (p_sigma = 0.1, p_mu = nothing, p_phi = nothing) | ||
|
||
dt = 0.01 | ||
dx = 0.01 | ||
opt = Flux.Optimisers.Adam(5e-2) | ||
|
||
prob = ParabolicPDEProblem(mu_, | ||
sigma_, | ||
nothing, | ||
tspan; | ||
g = phi, | ||
xspan, | ||
p_domain = p_domain, | ||
p_prototype = p_prototype) | ||
|
||
sol = solve(prob, NNParamKolmogorov(m, opt), sdealg, verbose = true, dt = 0.01, | ||
abstol = 1e-10, dx = 0.1, trajectories = trajectories, maxiters = 1000, | ||
use_gpu = false, dps = dps) | ||
``` | ||
Similarly we can parametrize the drift function `mu_` and the initial function `g`, and obtain a solution over all parameters and initial values. | ||
|
||
# Inferring on the solution from `NNParamKolmogorov`: | ||
```julia | ||
x_test = rand(xspan[1][1]:0.1:xspan[1][2], d) | ||
p_sigma_test = rand(p_domain.p_sigma[1]:dps.p_sigma:p_domain.p_sigma[2], 1, 1) | ||
t_test = rand(tspan[1]:dt:tspan[2], 1, 1) | ||
p_mu_test = nothing | ||
p_phi_test = nothing | ||
``` | ||
```julia | ||
sol.ufuns(x_test, t_test, p_sigma_test, p_mu_test, p_phi_test) | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
""" | ||
Algorithm for solving Kolmogorov Equations. | ||
```julia | ||
HighDimPDE.NNKolmogorov(chain, opt) | ||
``` | ||
Arguments: | ||
- `chain`: A Chain neural network with a d-dimensional output. | ||
- `opt`: The optimizer to train the neural network. Defaults to `ADAM(0.1)`. | ||
[1]Beck, Christian, et al. "Solving stochastic differential equations and Kolmogorov equations by means of deep learning." arXiv preprint arXiv:1806.00421 (2018). | ||
""" | ||
struct NNKolmogorov{C, O} <: HighDimPDEAlgorithm | ||
chain::C | ||
opt::O | ||
end | ||
NNKolmogorov(chain; opt = Flux.ADAM(0.1)) = NNKolmogorov(chain, opt) | ||
|
||
""" | ||
$(TYPEDSIGNATURES) | ||
Returns a `PIDESolution` object. | ||
# Arguments | ||
- `sdealg`: a SDE solver from [DifferentialEquations.jl](https://diffeq.sciml.ai/stable/solvers/sde_solve/). | ||
If not provided, the plain vanilla [DeepBSDE](https://arxiv.org/abs/1707.02568) method will be applied. | ||
If provided, the SDE associated with the PDE problem will be solved relying on | ||
methods from DifferentialEquations.jl, using [Ensemble solves](https://diffeq.sciml.ai/stable/features/ensemble/) | ||
via `sdealg`. Check the available `sdealg` on the | ||
[DifferentialEquations.jl doc](https://diffeq.sciml.ai/stable/solvers/sde_solve/). | ||
- `maxiters`: The number of training epochs. Defaults to `300` | ||
- `trajectories`: The number of trajectories simulated for training. Defaults to `100` | ||
- Extra keyword arguments passed to `solve` will be further passed to the SDE solver. | ||
""" | ||
function DiffEqBase.solve(prob::ParabolicPDEProblem, | ||
pdealg::HighDimPDE.NNKolmogorov, | ||
sdealg; | ||
ensemblealg = EnsembleThreads(), | ||
abstol = 1.0f-6, | ||
verbose = false, | ||
maxiters = 300, | ||
trajectories = 1000, | ||
save_everystep = false, | ||
use_gpu = false, | ||
dt, | ||
dx, | ||
kwargs...) | ||
tspan = prob.tspan | ||
sigma = prob.σ | ||
μ = prob.μ | ||
|
||
noise_rate_prototype = get(prob.kwargs, :noise_rate_prototype, nothing) | ||
phi = prob.g | ||
|
||
xspan = prob.kwargs.xspan | ||
|
||
xspans = isa(xspan, Tuple) ? [xspan] : xspan | ||
|
||
d = length(xspans) | ||
ts = tspan[1]:dt:tspan[2] | ||
xs = map(xspans) do xspan | ||
xspan[1]:dx:xspan[2] | ||
end | ||
N = size(ts) | ||
T = tspan[2] | ||
|
||
#hidden layer | ||
chain = pdealg.chain | ||
opt = pdealg.opt | ||
ps = Flux.params(chain) | ||
xi = mapreduce(x -> rand(x, 1, trajectories), vcat, xs) | ||
#Finding Solution to the SDE having initial condition xi. Y = Phi(S(X , T)) | ||
sdeproblem = SDEProblem(μ, | ||
sigma, | ||
xi[:, 1], | ||
tspan, | ||
noise_rate_prototype = noise_rate_prototype) | ||
|
||
function prob_func(prob, i, repeat) | ||
SDEProblem(prob.f, | ||
xi[:, i], | ||
prob.tspan, | ||
noise_rate_prototype = prob.noise_rate_prototype) | ||
end | ||
output_func(sol, i) = (sol.u[end], false) | ||
ensembleprob = EnsembleProblem(sdeproblem, | ||
prob_func = prob_func, | ||
output_func = output_func) | ||
sim = solve(ensembleprob, | ||
sdealg, | ||
ensemblealg, | ||
dt = dt, | ||
trajectories = trajectories, | ||
adaptive = false) | ||
|
||
x_sde = Array(sim) | ||
|
||
y = reduce(hcat, phi.(eachcol(x_sde))) | ||
|
||
y = use_gpu ? y |> gpu : y | ||
xi = use_gpu ? xi |> gpu : xi | ||
|
||
#MSE Loss Function | ||
loss(m, x, y) = Flux.mse(m(x), y) | ||
|
||
losses = AbstractFloat[] | ||
|
||
opt_state = Flux.setup(opt, chain) | ||
for epoch in 1:maxiters | ||
gs = Flux.gradient(chain) do model | ||
loss(model, xi, y) | ||
end | ||
Flux.update!(opt_state, chain, gs[1]) | ||
l = loss(chain, xi, y) | ||
@info "Current Epoch: $epoch Current Loss: $l" | ||
push!(losses, l) | ||
end | ||
# Flux.train!(loss, chain, data, opt; cb = callback) | ||
chainout = chain(xi) | ||
xi, chainout | ||
return PIDESolution(xi, ts, losses, chainout, chain, nothing) | ||
end |
Oops, something went wrong.