Skip to content

Commit

Permalink
Small clarification in train-kernel-parameters notebook (#496)
Browse files Browse the repository at this point in the history
* Small clarification in train-kernel-parameters notebook

* set remaining `passes` to true for piecewisepolynomial tests - now passing
  • Loading branch information
st-- authored Mar 25, 2023
1 parent 00d36a9 commit 401d556
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
10 changes: 5 additions & 5 deletions examples/train-kernel-parameters/script.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ plot!(x_test, sinc; label="true function")
# A simple way to ensure that the kernel parameters are positive
# is to optimize over the logarithm of the parameters.

function kernelcall(θ)
function kernel_creator(θ)
return (exp(θ[1]) * SqExponentialKernel() + exp(θ[2]) * Matern32Kernel())
ScaleTransform(exp(θ[3]))
end
Expand All @@ -52,7 +52,7 @@ nothing #hide
# the kernel parameters and normalization constant:

function f(x, x_train, y_train, θ)
k = kernelcall(θ[1:3])
k = kernel_creator(θ[1:3])
return kernelmatrix(k, x, x_train) *
((kernelmatrix(k, x_train) + exp(θ[4]) * I) \ y_train)
end
Expand Down Expand Up @@ -133,15 +133,15 @@ raw_initial_θ = (
flat_θ, unflatten = ParameterHandling.value_flatten(raw_initial_θ)
flat_θ #hide

# We define a few relevant functions and note that compared to the previous `kernelcall` function, we do not need explicit `exp`s.
# We define a few relevant functions and note that compared to the previous `kernel_creator` function, we do not need explicit `exp`s.

function kernelcall(θ)
function kernel_creator(θ)
return.k1 * SqExponentialKernel() + θ.k2 * Matern32Kernel()) ScaleTransform.k3)
end
nothing #hide

function f(x, x_train, y_train, θ)
k = kernelcall(θ)
k = kernel_creator(θ)
return kernelmatrix(k, x, x_train) *
((kernelmatrix(k, x_train) + θ.noise_var * I) \ y_train)
end
Expand Down
8 changes: 4 additions & 4 deletions test/basekernels/piecewisepolynomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,10 @@
nothing,
example_inputs(StableRNG(123456), T)...;
passes=(
unary=(true, true, false),
binary=(true, true, false),
diag_unary=(true, true, false),
diag_binary=(true, true, false),
unary=(true, true, true),
binary=(true, true, true),
diag_unary=(true, true, true),
diag_binary=(true, true, true),
),
)
end
Expand Down

0 comments on commit 401d556

Please sign in to comment.