Skip to content

Commit

Permalink
tests ok
Browse files Browse the repository at this point in the history
  • Loading branch information
PierreMartinon committed Dec 2, 2024
1 parent b1b9827 commit c2ce2b1
Show file tree
Hide file tree
Showing 8 changed files with 269 additions and 20 deletions.
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,22 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

[weakdeps]
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"

[extensions]
CTBasePlots = "Plots"
CTBaseLoadSave = ["JLD2", "JSON3"]
CTBasePlots = "Plots"

[compat]
DataStructures = "0.18"
DifferentiationInterface = "0.5"
DocStringExtensions = "0.9"
ForwardDiff = "0.10"
Interpolations = "0.15"
JLD2 = "0.5"
JSON3 = "1"
MLStyle = "0.4"
MacroTools = "0.5"
Parameters = "0.12"
Expand Down
136 changes: 136 additions & 0 deletions src/optimal_control_solution-setters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,142 @@ function OptimalControlSolution(
)
end


"""
$(TYPEDSIGNATURES)
Build OCP functional solution from discrete solution (given as raw variables and multipliers plus some optional infos)
"""
function OptimalControlSolution(
ocp::OptimalControlModel,
T,
X,
U,
v,
P;
objective = 0,
iterations = 0,
constraints_violation = 0,
message = "No msg",
stopping = nothing,
success = nothing,
constraints_types = (nothing, nothing, nothing, nothing, nothing),
constraints_mult = (nothing, nothing, nothing, nothing, nothing),
box_multipliers = (nothing, nothing, nothing, nothing, nothing, nothing),
)
dim_x = state_dimension(ocp)
dim_u = control_dimension(ocp)
dim_v = variable_dimension(ocp)

# check that time grid is strictly increasing
# if not proceed with list of indexes as time grid
if !issorted(T, lt = <=)
println(
"WARNING: time grid at solution is not strictly increasing, replacing with list of indices...",
)
println(T)
dim_NLP_steps = length(T) - 1
T = LinRange(0, dim_NLP_steps, dim_NLP_steps + 1)
end

# variables: remove additional state for lagrange cost
x = ctinterpolate(T, matrix2vec(X[:, 1:dim_x], 1))
p = ctinterpolate(T[1:(end - 1)], matrix2vec(P[:, 1:dim_x], 1))
u = ctinterpolate(T, matrix2vec(U[:, 1:dim_u], 1))

# force scalar output when dimension is 1
fx = (dim_x == 1) ? deepcopy(t -> x(t)[1]) : deepcopy(t -> x(t))
fu = (dim_u == 1) ? deepcopy(t -> u(t)[1]) : deepcopy(t -> u(t))
fp = (dim_x == 1) ? deepcopy(t -> p(t)[1]) : deepcopy(t -> p(t))
var = (dim_v == 1) ? v[1] : v

# misc infos
infos = Dict{Symbol, Any}()
infos[:constraints_violation] = constraints_violation

# nonlinear constraints and multipliers
control_constraints = t -> ctinterpolate(T, matrix2vec(constraints_types[1], 1))(t)
mult_control_constraints = t -> ctinterpolate(T, matrix2vec(constraints_mult[1], 1))(t)
state_constraints = t -> ctinterpolate(T, matrix2vec(constraints_types[2], 1))(t)
mult_state_constraints = t -> ctinterpolate(T, matrix2vec(constraints_mult[2], 1))(t)
mixed_constraints = t -> ctinterpolate(T, matrix2vec(constraints_types[3], 1))(t)
mult_mixed_constraints = t -> ctinterpolate(T, matrix2vec(constraints_mult[3], 1))(t)

# boundary and variable constraints
boundary_constraints = constraints_types[4]
mult_boundary_constraints = constraints_mult[4]
variable_constraints = constraints_types[5]
mult_variable_constraints = constraints_mult[5]

# box constraints multipliers
mult_state_box_lower = t -> ctinterpolate(T, matrix2vec(box_multipliers[1][:, 1:dim_x], 1))(t)
mult_state_box_upper = t -> ctinterpolate(T, matrix2vec(box_multipliers[2][:, 1:dim_x], 1))
mult_control_box_lower = t -> ctinterpolate(T, matrix2vec(box_multipliers[3][:, 1:dim_u], 1))(t)
mult_control_box_upper = t -> ctinterpolate(T, matrix2vec(box_multipliers[4][:, 1:dim_u], 1))
mult_variable_box_lower, mult_variable_box_upper = box_multipliers[5], box_multipliers[6]

# build and return solution
if is_variable_dependent(ocp)
return OptimalControlSolution(
ocp;
state = fx,
control = fu,
objective = objective,
costate = fp,
time_grid = T,
variable = var,
iterations = iterations,
stopping = stopping,
message = message,
success = success,
infos = infos,
control_constraints = control_constraints,
state_constraints = state_constraints,
mixed_constraints = mixed_constraints,
boundary_constraints = boundary_constraints,
variable_constraints = variable_constraints,
mult_control_constraints = mult_control_constraints,
mult_state_constraints = mult_state_constraints,
mult_mixed_constraints = mult_mixed_constraints,
mult_boundary_constraints = mult_boundary_constraints,
mult_variable_constraints = mult_variable_constraints,
mult_state_box_lower = mult_state_box_lower,
mult_state_box_upper = mult_state_box_upper,
mult_control_box_lower = mult_control_box_lower,
mult_control_box_upper = mult_control_box_upper,
mult_variable_box_lower = mult_variable_box_lower,
mult_variable_box_upper = mult_variable_box_upper,
)
else
return OptimalControlSolution(
ocp;
state = fx,
control = fu,
objective = objective,
costate = fp,
time_grid = T,
iterations = iterations,
stopping = stopping,
message = message,
success = success,
infos = infos,
control_constraints = control_constraints,
state_constraints = state_constraints,
mixed_constraints = mixed_constraints,
boundary_constraints = boundary_constraints,
mult_control_constraints = mult_control_constraints,
mult_state_constraints = mult_state_constraints,
mult_mixed_constraints = mult_mixed_constraints,
mult_boundary_constraints = mult_boundary_constraints,
mult_state_box_lower = mult_state_box_lower,
mult_state_box_upper = mult_state_box_upper,
mult_control_box_lower = mult_control_box_lower,
mult_control_box_upper = mult_control_box_upper,
)
end
end


# setters
#state!(sol::OptimalControlSolution, state::Function) = (sol.state = state; nothing)
#control!(sol::OptimalControlSolution, control::Function) = (sol.control = control; nothing)
Expand Down
8 changes: 6 additions & 2 deletions src/optimal_control_solution-type.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,5 +84,9 @@ export export_ocp_solution
export import_ocp_solution

# placeholders (see extension CTBaseLoadSave)
function export_ocp_solution end
function import_ocp_solution end
function export_ocp_solution(args...; kwargs...)
error("Requires JLD2 and JSON3 packages")
end
function import_ocp_solution(args...; kwargs...)
error("Requires JLD2 and JSON3 packages")
end
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
[deps]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using Aqua
using CTBase
using DifferentiationInterface: AutoForwardDiff
using Plots
using JLD2, JSON3
using Test

# functions and types that are not exported
Expand Down
Binary file added test/solution_test.jld2
Binary file not shown.
108 changes: 108 additions & 0 deletions test/solution_test.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
{
"time_grid": [
0,
0.1111111111111111,
0.2222222222222222,
0.3333333333333333,
0.4444444444444444,
0.5555555555555556,
0.6666666666666666,
0.7777777777777778,
0.8888888888888888,
1
],
"objective": 1,
"control": [
0,
0.2222222222222222,
0.4444444444444444,
0.6666666666666666,
0.8888888888888888,
1.1111111111111112,
1.3333333333333333,
1.5555555555555556,
1.7777777777777777,
2
],
"costate": [
[
0,
-1
],
[
0.1111111111111111,
-0.8888888888888888
],
[
0.2222222222222222,
-0.7777777777777778
],
[
0.3333333333333333,
-0.6666666666666667
],
[
0.4444444444444444,
-0.5555555555555556
],
[
0.5555555555555556,
-0.4444444444444444
],
[
0.6666666666666666,
-0.33333333333333337
],
[
0.7777777777777778,
-0.2222222222222222
],
[
0.8888888888888888,
-0.11111111111111116
]
],
"variable": null,
"state": [
[
0,
1
],
[
0.1111111111111111,
1.1111111111111112
],
[
0.2222222222222222,
1.2222222222222223
],
[
0.3333333333333333,
1.3333333333333333
],
[
0.4444444444444444,
1.4444444444444444
],
[
0.5555555555555556,
1.5555555555555556
],
[
0.6666666666666666,
1.6666666666666665
],
[
0.7777777777777778,
1.7777777777777777
],
[
0.8888888888888888,
1.8888888888888888
],
[
1,
2
]
]
}
28 changes: 12 additions & 16 deletions test/test_solution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ function test_solution()
end

times = range(0, 1, 10)
x = t -> t
x = t -> [t, t+1]
u = t -> 2t
p = t -> t
p = t -> [t, t-1]
obj = 1
sol = OptimalControlSolution(
ocp;
Expand All @@ -33,6 +33,12 @@ function test_solution()
@test all(control_discretized(sol) .== u.(times))
@test all(costate_discretized(sol) .== p.(times))

# test export / read solution in JSON format (NB. requires time grid in solution !)
println(sol.time_grid)
export_ocp_solution(sol; filename_prefix = "solution_test", format = :JSON)
sol_reloaded = import_ocp_solution(ocp; filename_prefix = "solution_test", format = :JSON)
@test sol.objective == sol_reloaded.objective

# NonFixed ocp
@def ocp begin
v R, variable
Expand All @@ -45,7 +51,7 @@ function test_solution()
(0.5u(t)^2) min
end

x = t -> t
x = t -> [t, t+1]
u = t -> 2t
obj = 1
v = 1
Expand All @@ -55,19 +61,9 @@ function test_solution()
@test typeof(sol) == OptimalControlSolution
@test_throws UndefKeywordError OptimalControlSolution(ocp; x, u, obj)


# test save / load solution in JLD2 format
@testset verbose = true showtiming = true ":save_load :JLD2" begin
export_ocp_solution(sol; filename_prefix = "solution_test")
sol_reloaded = import_ocp_solution(ocp; filename_prefix = "solution_test")
@test sol.objective == sol_reloaded.objective
end

# test export / read solution in JSON format
@testset verbose = true showtiming = true ":export_read :JSON" begin
export_ocp_solution(sol; filename_prefix = "solution_test", format = :JSON)
sol_reloaded = import_ocp_solution(ocp; filename_prefix = "solution_test", format = :JSON)
@test sol.objective == sol_reloaded.objective
end
export_ocp_solution(sol; filename_prefix = "solution_test")
sol_reloaded = import_ocp_solution(ocp; filename_prefix = "solution_test")
@test sol.objective == sol_reloaded.objective

end

0 comments on commit c2ce2b1

Please sign in to comment.