From e3d3f8ed35556b2fc29921a493846f6f2cff0bba Mon Sep 17 00:00:00 2001 From: Dennis Yatunin Date: Tue, 11 Feb 2025 17:24:12 -0800 Subject: [PATCH] Update examples to use new implicit solver interface --- .buildkite/pipeline.yml | 4 - docs/src/matrix_fields.md | 16 +- examples/hybrid/driver.jl | 1 - examples/hybrid/implicit_equation_jacobian.jl | 113 ++++ examples/hybrid/ode_config.jl | 2 +- examples/hybrid/schur_complement_W.jl | 232 -------- .../hybrid/staggered_nonhydrostatic_model.jl | 552 ++++++++---------- examples/implicit_solver_debugging_tools.jl | 56 -- src/Fields/broadcast.jl | 3 + src/MatrixFields/band_matrix_row.jl | 19 +- src/MatrixFields/field_matrix_solver.jl | 4 +- src/MatrixFields/field_matrix_with_solver.jl | 17 +- src/MatrixFields/field_name.jl | 15 + src/MatrixFields/field_name_dict.jl | 238 +++++--- src/MatrixFields/field_name_set.jl | 19 +- src/MatrixFields/single_field_solver.jl | 15 +- test/MatrixFields/field_names.jl | 319 ++++++---- test/Operators/finitedifference/linsolve.jl | 101 ---- test/runtests.jl | 1 - 19 files changed, 843 insertions(+), 884 deletions(-) create mode 100644 examples/hybrid/implicit_equation_jacobian.jl delete mode 100644 examples/hybrid/schur_complement_W.jl delete mode 100644 examples/implicit_solver_debugging_tools.jl delete mode 100644 test/Operators/finitedifference/linsolve.jl diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 6c268742f2..a39f562cff 100755 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -617,10 +617,6 @@ steps: key: unit_wfact command: "julia --color=yes --check-bounds=yes --project=.buildkite test/Operators/finitedifference/wfact.jl" - - label: "Unit: linsolve" - key: unit_linsolve - command: "julia --color=yes --check-bounds=yes --project=.buildkite test/Operators/finitedifference/linsolve.jl" - - label: "Unit: fd tensor" key: unit_fd_tensor command: "julia --color=yes --check-bounds=yes --project=.buildkite test/Operators/finitedifference/tensor.jl" diff --git a/docs/src/matrix_fields.md b/docs/src/matrix_fields.md index a08f86ee2f..5bca76464e 100644 --- a/docs/src/matrix_fields.md +++ b/docs/src/matrix_fields.md @@ -26,7 +26,16 @@ MultiplyColumnwiseBandMatrixField operator_matrix ``` -# Linear Solvers +## Vectors and Matrices of Fields + +``@docs +FieldNameDict +identity_field_matrix +field_vector_view +concrete_field_vector +`` + +## Linear Solvers ```@docs FieldMatrixSolverAlgorithm @@ -42,7 +51,7 @@ StationaryIterativeSolve ApproximateBlockArrowheadIterativeSolve ``` -# Preconditioners +## Preconditioners ```@docs PreconditionerAlgorithm @@ -67,9 +76,6 @@ FieldName @name FieldNameTree FieldNameSet -FieldNameDict -field_vector_view -concrete_field_vector is_lazy lazy_main_diagonal lazy_mul diff --git a/examples/hybrid/driver.jl b/examples/hybrid/driver.jl index b55a960bee..4cc1bf42eb 100644 --- a/examples/hybrid/driver.jl +++ b/examples/hybrid/driver.jl @@ -54,7 +54,6 @@ using JLD2 const FT = get(ENV, "FLOAT_TYPE", "Float32") == "Float32" ? Float32 : Float64 -include("../implicit_solver_debugging_tools.jl") include("../ordinary_diff_eq_bug_fixes.jl") include("../common_spaces.jl") diff --git a/examples/hybrid/implicit_equation_jacobian.jl b/examples/hybrid/implicit_equation_jacobian.jl new file mode 100644 index 0000000000..03f60a1158 --- /dev/null +++ b/examples/hybrid/implicit_equation_jacobian.jl @@ -0,0 +1,113 @@ +import LinearAlgebra: ldiv! +using ClimaCore: Spaces, Fields, Operators +using ClimaCore.Utilities: half +using ClimaCore.MatrixFields +using ClimaCore.MatrixFields: @name + +struct ImplicitEquationJacobian{TJ, RJ, F, T1, T2} + # nonzero blocks of the implicit tendency's Jacobian + ∂Yₜ∂Y::TJ + + # the full implicit residual's Jacobian, and its linear solver + ∂R∂Y::RJ + + # whether this struct is used to compute Wfact_t or Wfact + transform::Bool + + # flags for computing the Jacobian + flags::F + + # cache that is used to evaluate ldiv! + temp1::T1 + temp2::T2 +end + +function ImplicitEquationJacobian(Y, transform, flags) + FT = eltype(Y) + + ᶜρ_name = @name(c.ρ) + ᶜ𝔼_name = if :ρθ in propertynames(Y.c) + @name(c.ρθ) + elseif :ρe in propertynames(Y.c) + @name(c.ρe) + elseif :ρe_int in propertynames(Y.c) + @name(c.ρe_int) + end + ᶠ𝕄_name = @name(f.w) + + BidiagonalRow_C3 = BidiagonalMatrixRow{C3{FT}} + BidiagonalRow_ACT3 = BidiagonalMatrixRow{Adjoint{FT, CT3{FT}}} + QuaddiagonalRow_ACT3 = QuaddiagonalMatrixRow{Adjoint{FT, CT3{FT}}} + TridiagonalRow_C3xACT3 = + TridiagonalMatrixRow{typeof(C3(FT(0)) * CT3(FT(0))')} + ∂ᶜ𝔼ₜ∂ᶠ𝕄_Row_ACT3 = + flags.∂ᶜ𝔼ₜ∂ᶠ𝕄_mode == :exact && :ρe in propertynames(Y.c) ? + QuaddiagonalRow_ACT3 : BidiagonalRow_ACT3 + ∂Yₜ∂Y = FieldMatrix( + (ᶜρ_name, ᶠ𝕄_name) => Fields.Field(BidiagonalRow_ACT3, axes(Y.c)), + (ᶜ𝔼_name, ᶠ𝕄_name) => Fields.Field(∂ᶜ𝔼ₜ∂ᶠ𝕄_Row_ACT3, axes(Y.c)), + (ᶠ𝕄_name, ᶜρ_name) => Fields.Field(BidiagonalRow_C3, axes(Y.f)), + (ᶠ𝕄_name, ᶜ𝔼_name) => Fields.Field(BidiagonalRow_C3, axes(Y.f)), + (ᶠ𝕄_name, ᶠ𝕄_name) => + Fields.Field(TridiagonalRow_C3xACT3, axes(Y.f)), + ) + + dtγ = FT(1) + I = MatrixFields.identity_field_matrix(Y) # one(∂Yₜ∂Y) can't get every block + ∂R∂Y = transform ? I ./ dtγ .- ∂Yₜ∂Y : dtγ .* ∂Yₜ∂Y .- I + alg = MatrixFields.BlockArrowheadSolve(ᶜρ_name, ᶜ𝔼_name) + + return ImplicitEquationJacobian( + ∂Yₜ∂Y, + FieldMatrixWithSolver(∂R∂Y, Y, alg), + transform, + flags, + similar(Y), + similar(Y), + ) +end + +# Required for compatibility with OrdinaryDiffEq.jl +Base.similar(A::ImplicitEquationJacobian) = ImplicitEquationJacobian( + similar(A.∂Yₜ∂Y), + similar(A.∂R∂Y), + A.transform, + A.flags, + A.temp1, + A.temp2, +) + +# Required for compatibility with ClimaTimeSteppers.jl +Base.zero(A::ImplicitEquationJacobian) = ImplicitEquationJacobian( + zero(A.∂Yₜ∂Y), + zero(A.∂R∂Y), + A.transform, + A.flags, + A.temp1, + A.temp2, +) + +# This method for ldiv! is called by Newton's method from ClimaTimeSteppers.jl. +# It solves ∂R∂Y * x = b for x, where R is the implicit residual. +ldiv!( + x::Fields.FieldVector, + A::ImplicitEquationJacobian, + b::Fields.FieldVector, +) = ldiv!(x, A.∂R∂Y, b) + +# This method for ldiv! is called by Krylov.jl from ClimaTimeSteppers.jl. +# See https://github.com/JuliaSmoothOptimizers/Krylov.jl/issues/605 for a +# a similar way of handling the AbstractVectors generated by Krylov.jl. +function ldiv!( + x::AbstractVector, + A::ImplicitEquationJacobian, + b::AbstractVector, +) + A.temp_b .= b + ldiv!(A.temp_x, A, A.temp_b) + x .= A.temp_x +end + +# This function is called by Newton's method from OrdinaryDiffEq.jl. +linsolve!(::Type{Val{:init}}, f, u0; kwargs...) = _linsolve! +_linsolve!(x, A, b, update_matrix = false; kwargs...) = ldiv!(x, A, b) diff --git a/examples/hybrid/ode_config.jl b/examples/hybrid/ode_config.jl index 7c72087b7e..c566c405e7 100644 --- a/examples/hybrid/ode_config.jl +++ b/examples/hybrid/ode_config.jl @@ -14,7 +14,7 @@ use_transform(ode_algo) = !is_imex_CTS_algo(ode_algo) function jac_kwargs(ode_algo, Y, jacobi_flags) if is_imex_CTS_algo(ode_algo) - W = SchurComplementW(Y, use_transform(ode_algo), jacobi_flags) + W = ImplicitEquationJacobian(Y, use_transform(ode_algo), jacobi_flags) if use_transform(ode_algo) return (; jac_prototype = W, Wfact_t = Wfact!) else diff --git a/examples/hybrid/schur_complement_W.jl b/examples/hybrid/schur_complement_W.jl deleted file mode 100644 index e1fc600196..0000000000 --- a/examples/hybrid/schur_complement_W.jl +++ /dev/null @@ -1,232 +0,0 @@ -using LinearAlgebra - -using ClimaCore: Spaces, Fields, Operators -using ClimaCore.Utilities: half - -const compose = Operators.ComposeStencils() -const apply = Operators.ApplyStencil() - -struct SchurComplementW{F, FT, J1, J2, J3, J4, S, T} - # whether this struct is used to compute Wfact_t or Wfact - transform::Bool - - # flags for computing the Jacobian - flags::F - - # reference to dtγ, which is specified by the ODE solver - dtγ_ref::FT - - # nonzero blocks of the Jacobian - ∂ᶜρₜ∂ᶠ𝕄::J1 - ∂ᶜ𝔼ₜ∂ᶠ𝕄::J2 - ∂ᶠ𝕄ₜ∂ᶜ𝔼::J3 - ∂ᶠ𝕄ₜ∂ᶜρ::J3 - ∂ᶠ𝕄ₜ∂ᶠ𝕄::J4 - - # cache for the Schur complement linear solve - S::S - - # whether to test the Jacobian and linear solver - test::Bool - - # cache that is used to evaluate ldiv! - temp1::T - temp2::T -end - -function Base.zero(jac::SchurComplementW) - return SchurComplementW( - jac.transform, - jac.flags, - jac.dtγ_ref, - Base.zero(jac.∂ᶜρₜ∂ᶠ𝕄), - Base.zero(jac.∂ᶜ𝔼ₜ∂ᶠ𝕄), - Base.zero(jac.∂ᶠ𝕄ₜ∂ᶜ𝔼), - Base.zero(jac.∂ᶠ𝕄ₜ∂ᶜρ), - Base.zero(jac.∂ᶠ𝕄ₜ∂ᶠ𝕄), - Base.zero(jac.S), - jac.test, - Base.zero(jac.temp1), - Base.zero(jac.temp2), - ) -end - - -function SchurComplementW(Y, transform, flags, test = false) - FT = eltype(Y) - dtγ_ref = Ref(zero(FT)) - center_space = axes(Y.c) - face_space = axes(Y.f) - - # TODO: Automate this. - J_eltype1 = Operators.StencilCoefs{-half, half, NTuple{2, FT}} - J_eltype2 = - flags.∂ᶜ𝔼ₜ∂ᶠ𝕄_mode == :exact && :ρe in propertynames(Y.c) ? - Operators.StencilCoefs{-(1 + half), 1 + half, NTuple{4, FT}} : J_eltype1 - J_eltype3 = Operators.StencilCoefs{-1, 1, NTuple{3, FT}} - ∂ᶜρₜ∂ᶠ𝕄 = Fields.Field(J_eltype1, center_space) - ∂ᶜ𝔼ₜ∂ᶠ𝕄 = Fields.Field(J_eltype2, center_space) - ∂ᶠ𝕄ₜ∂ᶜ𝔼 = Fields.Field(J_eltype1, face_space) - ∂ᶠ𝕄ₜ∂ᶜρ = Fields.Field(J_eltype1, face_space) - ∂ᶠ𝕄ₜ∂ᶠ𝕄 = Fields.Field(J_eltype3, face_space) - - # TODO: Automate this. - S_eltype = Operators.StencilCoefs{-1, 1, NTuple{3, FT}} - S = Fields.Field(S_eltype, face_space) - N = Spaces.nlevels(face_space) - - SchurComplementW{ - typeof(flags), - typeof(dtγ_ref), - typeof(∂ᶜρₜ∂ᶠ𝕄), - typeof(∂ᶜ𝔼ₜ∂ᶠ𝕄), - typeof(∂ᶠ𝕄ₜ∂ᶜρ), - typeof(∂ᶠ𝕄ₜ∂ᶠ𝕄), - typeof(S), - typeof(Y), - }( - transform, - flags, - dtγ_ref, - ∂ᶜρₜ∂ᶠ𝕄, - ∂ᶜ𝔼ₜ∂ᶠ𝕄, - ∂ᶠ𝕄ₜ∂ᶜ𝔼, - ∂ᶠ𝕄ₜ∂ᶜρ, - ∂ᶠ𝕄ₜ∂ᶠ𝕄, - S, - test, - similar(Y), - similar(Y), - ) -end - -# We only use Wfact, but the implicit/IMEX solvers require us to pass -# jac_prototype, then call similar(jac_prototype) to obtain J and Wfact. Here -# is a temporary workaround to avoid unnecessary allocations. -Base.similar(w::SchurComplementW) = w - -#= -A = [-I 0 dtγ ∂ᶜρₜ∂ᶠ𝕄 ; - 0 -I dtγ ∂ᶜ𝔼ₜ∂ᶠ𝕄 ; - dtγ ∂ᶠ𝕄ₜ∂ᶜρ dtγ ∂ᶠ𝕄ₜ∂ᶜ𝔼 dtγ ∂ᶠ𝕄ₜ∂ᶠ𝕄 - I] = - [-I 0 A13 ; - 0 -I A23 ; - A31 A32 A33 - I] -b = [b1; b2; b3] -x = [x1; x2; x3] -Solving A x = b: - -x1 + A13 x3 = b1 ==> x1 = -b1 + A13 x3 (1) - -x2 + A23 x3 = b2 ==> x2 = -b2 + A23 x3 (2) - A31 x1 + A32 x2 + (A33 - I) x3 = b3 (3) -Substitute (1) and (2) into (3): - A31 (-b1 + A13 x3) + A32 (-b2 + A23 x3) + (A33 - I) x3 = b3 ==> - (A31 A13 + A32 A23 + A33 - I) x3 = b3 + A31 b1 + A32 b2 ==> - x3 = (A31 A13 + A32 A23 + A33 - I) \ (b3 + A31 b1 + A32 b2) -Finally, use (1) and (2) to get x1 and x2. -Note: The matrix S = A31 A13 + A32 A23 + A33 - I is the "Schur complement" of -[-I 0; 0 -I] (the top-left 4 blocks) in A. -=# -# Function required by OrdinaryDiffEq.jl -linsolve!(::Type{Val{:init}}, f, u0; kwargs...) = _linsolve! -_linsolve!(x, A, b, update_matrix = false; kwargs...) = - LinearAlgebra.ldiv!(x, A, b) - -# Function required by Krylov.jl (x and b can be AbstractVectors) -# See https://github.com/JuliaSmoothOptimizers/Krylov.jl/issues/605 for a -# related issue that requires the same workaround. -function LinearAlgebra.ldiv!(x, A::SchurComplementW, b) - A.temp1 .= b - LinearAlgebra.ldiv!(A.temp2, A, A.temp1) - x .= A.temp2 -end - -function LinearAlgebra.ldiv!( - x::Fields.FieldVector, - A::SchurComplementW, - b::Fields.FieldVector, -) - (; dtγ_ref, ∂ᶜρₜ∂ᶠ𝕄, ∂ᶜ𝔼ₜ∂ᶠ𝕄, ∂ᶠ𝕄ₜ∂ᶜ𝔼, ∂ᶠ𝕄ₜ∂ᶜρ, ∂ᶠ𝕄ₜ∂ᶠ𝕄) = A - (; S) = A - dtγ = dtγ_ref[] - - xᶜρ = x.c.ρ - bᶜρ = b.c.ρ - if :ρθ in propertynames(x.c) - xᶜ𝔼 = x.c.ρθ - bᶜ𝔼 = b.c.ρθ - elseif :ρe in propertynames(x.c) - xᶜ𝔼 = x.c.ρe - bᶜ𝔼 = b.c.ρe - elseif :ρe_int in propertynames(x.c) - xᶜ𝔼 = x.c.ρe_int - bᶜ𝔼 = b.c.ρe_int - end - if :ρw in propertynames(x.f) - xᶠ𝕄 = x.f.ρw.components.data.:1 - bᶠ𝕄 = b.f.ρw.components.data.:1 - elseif :w in propertynames(x.f) - xᶠ𝕄 = x.f.w.components.data.:1 - bᶠ𝕄 = b.f.w.components.data.:1 - end - - # TODO: Extend LinearAlgebra.I to work with stencil fields. - FT = eltype(eltype(S)) - I = Ref(Operators.StencilCoefs{-1, 1}((zero(FT), one(FT), zero(FT)))) - if Operators.bandwidths(eltype(∂ᶜ𝔼ₜ∂ᶠ𝕄)) != (-half, half) - str = "The linear solver cannot yet be run with the given ∂ᶜ𝔼ₜ/∂ᶠ𝕄 \ - block, since it has more than 2 diagonals. So, ∂ᶜ𝔼ₜ/∂ᶠ𝕄 will \ - be set to 0 for the Schur complement computation. Consider \ - changing the ∂ᶜ𝔼ₜ∂ᶠ𝕄_mode or the energy variable." - @warn str maxlog = 1 - @. S = dtγ^2 * compose(∂ᶠ𝕄ₜ∂ᶜρ, ∂ᶜρₜ∂ᶠ𝕄) + dtγ * ∂ᶠ𝕄ₜ∂ᶠ𝕄 - I - else - @. S = - dtγ^2 * compose(∂ᶠ𝕄ₜ∂ᶜρ, ∂ᶜρₜ∂ᶠ𝕄) + - dtγ^2 * compose(∂ᶠ𝕄ₜ∂ᶜ𝔼, ∂ᶜ𝔼ₜ∂ᶠ𝕄) + - dtγ * ∂ᶠ𝕄ₜ∂ᶠ𝕄 - I - end - - @. xᶠ𝕄 = bᶠ𝕄 + dtγ * (apply(∂ᶠ𝕄ₜ∂ᶜρ, bᶜρ) + apply(∂ᶠ𝕄ₜ∂ᶜ𝔼, bᶜ𝔼)) - - Operators.column_thomas_solve!(S, xᶠ𝕄) - - @. xᶜρ = -bᶜρ + dtγ * apply(∂ᶜρₜ∂ᶠ𝕄, xᶠ𝕄) - @. xᶜ𝔼 = -bᶜ𝔼 + dtγ * apply(∂ᶜ𝔼ₜ∂ᶠ𝕄, xᶠ𝕄) - - if A.test && Operators.bandwidths(eltype(∂ᶜ𝔼ₜ∂ᶠ𝕄)) == (-half, half) - Ni, Nj, _, Nv, Nh = size(Spaces.local_geometry_data(axes(xᶜρ))) - ∂Yₜ∂Y = Array{FT}(undef, 3 * Nv + 1, 3 * Nv + 1) - ΔY = Array{FT}(undef, 3 * Nv + 1) - ΔΔY = Array{FT}(undef, 3 * Nv + 1) - for h in 1:Nh, j in 1:Nj, i in 1:Ni - ∂Yₜ∂Y .= zero(FT) - ∂Yₜ∂Y[1:Nv, (2 * Nv + 1):(3 * Nv + 1)] .= - matrix_column(∂ᶜρₜ∂ᶠ𝕄, axes(x.f), i, j, h) - ∂Yₜ∂Y[(Nv + 1):(2 * Nv), (2 * Nv + 1):(3 * Nv + 1)] .= - matrix_column(∂ᶜ𝔼ₜ∂ᶠ𝕄, axes(x.f), i, j, h) - ∂Yₜ∂Y[(2 * Nv + 1):(3 * Nv + 1), 1:Nv] .= - matrix_column(∂ᶠ𝕄ₜ∂ᶜρ, axes(x.c), i, j, h) - ∂Yₜ∂Y[(2 * Nv + 1):(3 * Nv + 1), (Nv + 1):(2 * Nv)] .= - matrix_column(∂ᶠ𝕄ₜ∂ᶜ𝔼, axes(x.c), i, j, h) - ∂Yₜ∂Y[(2 * Nv + 1):(3 * Nv + 1), (2 * Nv + 1):(3 * Nv + 1)] .= - matrix_column(∂ᶠ𝕄ₜ∂ᶠ𝕄, axes(x.f), i, j, h) - ΔY[1:Nv] .= vector_column(xᶜρ, i, j, h) - ΔY[(Nv + 1):(2 * Nv)] .= vector_column(xᶜ𝔼, i, j, h) - ΔY[(2 * Nv + 1):(3 * Nv + 1)] .= vector_column(xᶠ𝕄, i, j, h) - ΔΔY[1:Nv] .= vector_column(bᶜρ, i, j, h) - ΔΔY[(Nv + 1):(2 * Nv)] .= vector_column(bᶜ𝔼, i, j, h) - ΔΔY[(2 * Nv + 1):(3 * Nv + 1)] .= vector_column(bᶠ𝕄, i, j, h) - @assert (-LinearAlgebra.I + dtγ * ∂Yₜ∂Y) * ΔY ≈ ΔΔY - end - end - - if :ρuₕ in propertynames(x.c) - @. x.c.ρuₕ = -b.c.ρuₕ - elseif :uₕ in propertynames(x.c) - @. x.c.uₕ = -b.c.uₕ - end - - if A.transform - x .*= dtγ - end -end diff --git a/examples/hybrid/staggered_nonhydrostatic_model.jl b/examples/hybrid/staggered_nonhydrostatic_model.jl index 302de7696f..76d5be9093 100644 --- a/examples/hybrid/staggered_nonhydrostatic_model.jl +++ b/examples/hybrid/staggered_nonhydrostatic_model.jl @@ -1,8 +1,7 @@ -using LinearAlgebra: ×, norm, norm_sqr, dot - +using LinearAlgebra: ×, norm, norm_sqr, dot, Adjoint using ClimaCore: Operators, Fields -include("schur_complement_W.jl") +include("implicit_equation_jacobian.jl") include("hyperdiffusion.jl") # Constants required before `include("staggered_nonhydrostatic_model.jl")` @@ -22,6 +21,13 @@ const cp_d = R_d / κ # heat capacity at constant pressure const cv_d = cp_d - R_d # heat capacity at constant volume const γ = cp_d / cv_d # heat capacity ratio +const C3 = Geometry.Covariant3Vector +const C12 = Geometry.Covariant12Vector +const C123 = Geometry.Covariant123Vector +const CT1 = Geometry.Contravariant1Vector +const CT3 = Geometry.Contravariant3Vector +const CT12 = Geometry.Contravariant12Vector + const divₕ = Operators.Divergence() const wdivₕ = Operators.WeakDivergence() const gradₕ = Operators.Gradient() @@ -35,16 +41,16 @@ const ᶠinterp = Operators.InterpolateC2F( top = Operators.Extrapolate(), ) const ᶜdivᵥ = Operators.DivergenceF2C( - top = Operators.SetValue(Geometry.Contravariant3Vector(FT(0))), - bottom = Operators.SetValue(Geometry.Contravariant3Vector(FT(0))), + top = Operators.SetValue(CT3(FT(0))), + bottom = Operators.SetValue(CT3(FT(0))), ) const ᶠgradᵥ = Operators.GradientC2F( - bottom = Operators.SetGradient(Geometry.Covariant3Vector(FT(0))), - top = Operators.SetGradient(Geometry.Covariant3Vector(FT(0))), + bottom = Operators.SetGradient(C3(FT(0))), + top = Operators.SetGradient(C3(FT(0))), ) const ᶠcurlᵥ = Operators.CurlC2F( - bottom = Operators.SetCurl(Geometry.Contravariant12Vector(FT(0), FT(0))), - top = Operators.SetCurl(Geometry.Contravariant12Vector(FT(0), FT(0))), + bottom = Operators.SetCurl(CT12(FT(0), FT(0))), + top = Operators.SetCurl(CT12(FT(0), FT(0))), ) const ᶜFC = Operators.FluxCorrectionC2C( bottom = Operators.Extrapolate(), @@ -56,12 +62,25 @@ const ᶠupwind_product3 = Operators.Upwind3rdOrderBiasedProductC2F( top = Operators.ThirdOrderOneSided(), ) -const ᶜinterp_stencil = Operators.Operator2Stencil(ᶜinterp) -const ᶠinterp_stencil = Operators.Operator2Stencil(ᶠinterp) -const ᶜdivᵥ_stencil = Operators.Operator2Stencil(ᶜdivᵥ) -const ᶠgradᵥ_stencil = Operators.Operator2Stencil(ᶠgradᵥ) +const ᶜinterp_matrix = MatrixFields.operator_matrix(ᶜinterp) +const ᶠinterp_matrix = MatrixFields.operator_matrix(ᶠinterp) +const ᶜdivᵥ_matrix = MatrixFields.operator_matrix(ᶜdivᵥ) +const ᶠgradᵥ_matrix = MatrixFields.operator_matrix(ᶠgradᵥ) +const ᶠupwind_product1_matrix = MatrixFields.operator_matrix(ᶠupwind_product1) +const ᶠupwind_product3_matrix = MatrixFields.operator_matrix(ᶠupwind_product3) -const C123 = Geometry.Covariant123Vector +const ᶠno_flux = Operators.SetBoundaryOperator( + top = Operators.SetValue(CT3(FT(0))), + bottom = Operators.SetValue(CT3(FT(0))), +) +const ᶠno_flux_row1 = Operators.SetBoundaryOperator( + top = Operators.SetValue(zero(BidiagonalMatrixRow{CT3{FT}})), + bottom = Operators.SetValue(zero(BidiagonalMatrixRow{CT3{FT}})), +) +const ᶠno_flux_row3 = Operators.SetBoundaryOperator( + top = Operators.SetValue(zero(QuaddiagonalMatrixRow{CT3{FT}})), + bottom = Operators.SetValue(zero(QuaddiagonalMatrixRow{CT3{FT}})), +) pressure_ρθ(ρθ) = p_0 * (ρθ * R_d / p_0)^γ pressure_ρe(ρe, K, Φ, ρ) = ρ * R_d * ((ρe / ρ - K - Φ) / cv_d + T_tri) @@ -79,24 +98,32 @@ function default_cache(ᶜlocal_geometry, ᶠlocal_geometry, Y, upwinding_mode) else ᶜf = map(_ -> f, ᶜlocal_geometry) end - ᶜf = @. Geometry.Contravariant3Vector(Geometry.WVector(ᶜf)) + ᶜf = @. CT3(Geometry.WVector(ᶜf)) + ᶠupwind_product, ᶠupwind_product_matrix, ᶠno_flux_row = + if upwinding_mode == :first_order + ᶠupwind_product1, ᶠupwind_product1_matrix, ᶠno_flux_row1 + elseif upwinding_mode == :third_order + ᶠupwind_product3, ᶠupwind_product3_matrix, ᶠno_flux_row3 + else + nothing, nothing, nothing + end return (; - ᶜuvw = similar(ᶜlocal_geometry, Geometry.Covariant123Vector{FT}), + ᶜuvw = similar(ᶜlocal_geometry, C123{FT}), ᶜK = similar(ᶜlocal_geometry, FT), ᶜΦ = grav .* ᶜcoord.z, ᶜp = similar(ᶜlocal_geometry, FT), - ᶜω³ = similar(ᶜlocal_geometry, Geometry.Contravariant3Vector{FT}), - ᶠω¹² = similar(ᶠlocal_geometry, Geometry.Contravariant12Vector{FT}), - ᶠu¹² = similar(ᶠlocal_geometry, Geometry.Contravariant12Vector{FT}), - ᶠu³ = similar(ᶠlocal_geometry, Geometry.Contravariant3Vector{FT}), + ᶜω³ = similar(ᶜlocal_geometry, CT3{FT}), + ᶠω¹² = similar(ᶠlocal_geometry, CT12{FT}), + ᶠu¹² = similar(ᶠlocal_geometry, CT12{FT}), + ᶠu³ = similar(ᶠlocal_geometry, CT3{FT}), ᶜf, - ∂ᶜK∂ᶠw_data = similar( + ∂ᶜK∂ᶠw = similar( ᶜlocal_geometry, - Operators.StencilCoefs{-half, half, NTuple{2, FT}}, + BidiagonalMatrixRow{Adjoint{FT, CT3{FT}}}, ), - ᶠupwind_product = upwinding_mode == :first_order ? ᶠupwind_product1 : - upwinding_mode == :third_order ? ᶠupwind_product3 : - nothing, + ᶠupwind_product, + ᶠupwind_product_matrix, + ᶠno_flux_row, ghost_buffer = ( c = Spaces.create_dss_buffer(Y.c), f = Spaces.create_dss_buffer(Y.f), @@ -115,14 +142,6 @@ function implicit_tendency!(Yₜ, Y, p, t) ᶠw = Y.f.w (; ᶜK, ᶜΦ, ᶜp, ᶠupwind_product) = p - # Used for automatically computing the Jacobian ∂Yₜ/∂Y. Currently requires - # allocation because the cache is stored separately from Y, which means that - # similar(Y, <:Dual) doesn't allocate an appropriate cache for computing Yₜ. - if eltype(Y) <: Dual - ᶜK = similar(ᶜρ) - ᶜp = similar(ᶜρ) - end - @. ᶜK = norm_sqr(C123(ᶜuₕ) + C123(ᶜinterp(ᶠw))) / 2 @. Yₜ.c.ρ = -(ᶜdivᵥ(ᶠinterp(ᶜρ) * ᶠw)) @@ -152,7 +171,7 @@ function implicit_tendency!(Yₜ, Y, p, t) if isnothing(ᶠupwind_product) @. Yₜ.c.ρe_int = -( ᶜdivᵥ(ᶠinterp(ᶜρe_int + ᶜp) * ᶠw) - - ᶜinterp(dot(ᶠgradᵥ(ᶜp), Geometry.Contravariant3Vector(ᶠw))) + ᶜinterp(dot(ᶠgradᵥ(ᶜp), CT3(ᶠw))) ) # or, equivalently, # Yₜ.c.ρe_int = -(ᶜdivᵥ(ᶠinterp(ᶜρe_int) * ᶠw) + ᶜp * ᶜdivᵥ(ᶠw)) @@ -161,13 +180,12 @@ function implicit_tendency!(Yₜ, Y, p, t) ᶜdivᵥ( ᶠinterp(Y.c.ρ) * ᶠupwind_product(ᶠw, (ᶜρe_int + ᶜp) / Y.c.ρ), - ) - - ᶜinterp(dot(ᶠgradᵥ(ᶜp), Geometry.Contravariant3Vector(ᶠw))) + ) - ᶜinterp(dot(ᶠgradᵥ(ᶜp), CT3(ᶠw))) ) end end - Yₜ.c.uₕ .= Ref(zero(eltype(Yₜ.c.uₕ))) + Yₜ.c.uₕ .= (zero(eltype(Yₜ.c.uₕ)),) @. Yₜ.f.w = -(ᶠgradᵥ(ᶜp) / ᶠinterp(ᶜρ) + ᶠgradᵥ(ᶜK + ᶜΦ)) @@ -229,12 +247,10 @@ function default_remaining_tendency!(Yₜ, Y, p, t) @. ᶜp = pressure_ρe_int(ᶜρe_int, ᶜρ) if point_type <: Geometry.Abstract3DPoint @. Yₜ.c.ρe_int -= - divₕ((ᶜρe_int + ᶜp) * ᶜuvw) - - dot(gradₕ(ᶜp), Geometry.Contravariant12Vector(ᶜuₕ)) + divₕ((ᶜρe_int + ᶜp) * ᶜuvw) - dot(gradₕ(ᶜp), CT12(ᶜuₕ)) else @. Yₜ.c.ρe_int -= - divₕ((ᶜρe_int + ᶜp) * ᶜuvw) - - dot(gradₕ(ᶜp), Geometry.Contravariant1Vector(ᶜuₕ)) + divₕ((ᶜρe_int + ᶜp) * ᶜuvw) - dot(gradₕ(ᶜp), CT1(ᶜuₕ)) end @. Yₜ.c.ρe_int -= ᶜdivᵥ(ᶠinterp((ᶜρe_int + ᶜp) * ᶜuₕ)) # or, equivalently, @@ -249,22 +265,20 @@ function default_remaining_tendency!(Yₜ, Y, p, t) @. ᶜω³ = curlₕ(ᶜuₕ) @. ᶠω¹² = curlₕ(ᶠw) elseif point_type <: Geometry.Abstract2DPoint - ᶜω³ .= Ref(zero(eltype(ᶜω³))) - @. ᶠω¹² = Geometry.Contravariant12Vector(curlₕ(ᶠw)) + ᶜω³ .= (zero(eltype(ᶜω³)),) + @. ᶠω¹² = CT12(curlₕ(ᶠw)) end @. ᶠω¹² += ᶠcurlᵥ(ᶜuₕ) # TODO: Modify to account for topography - @. ᶠu¹² = Geometry.Contravariant12Vector(ᶠinterp(ᶜuₕ)) - @. ᶠu³ = Geometry.Contravariant3Vector(ᶠw) + @. ᶠu¹² = CT12(ᶠinterp(ᶜuₕ)) + @. ᶠu³ = CT3(ᶠw) - @. Yₜ.c.uₕ -= - ᶜinterp(ᶠω¹² × ᶠu³) + (ᶜf + ᶜω³) × Geometry.Contravariant12Vector(ᶜuₕ) + @. Yₜ.c.uₕ -= ᶜinterp(ᶠω¹² × ᶠu³) + (ᶜf + ᶜω³) × CT12(ᶜuₕ) if point_type <: Geometry.Abstract3DPoint @. Yₜ.c.uₕ -= gradₕ(ᶜp) / ᶜρ + gradₕ(ᶜK + ᶜΦ) elseif point_type <: Geometry.Abstract2DPoint - @. Yₜ.c.uₕ -= - Geometry.Covariant12Vector(gradₕ(ᶜp) / ᶜρ + gradₕ(ᶜK + ᶜΦ)) + @. Yₜ.c.uₕ -= C12(gradₕ(ᶜp) / ᶜρ + gradₕ(ᶜK + ᶜΦ)) end @. Yₜ.f.w -= ᶠω¹² × ᶠu¹² @@ -273,46 +287,52 @@ end additional_tendency!(Yₜ, Y, p, t) = nothing function Wfact!(W, Y, p, dtγ, t) - (; flags, dtγ_ref, ∂ᶜρₜ∂ᶠ𝕄, ∂ᶜ𝔼ₜ∂ᶠ𝕄, ∂ᶠ𝕄ₜ∂ᶜ𝔼, ∂ᶠ𝕄ₜ∂ᶜρ, ∂ᶠ𝕄ₜ∂ᶠ𝕄) = W + (; ∂Yₜ∂Y, ∂R∂Y, transform, flags) = W ᶜρ = Y.c.ρ ᶜuₕ = Y.c.uₕ ᶠw = Y.f.w - (; ᶜK, ᶜΦ, ᶜp, ∂ᶜK∂ᶠw_data, ᶠupwind_product) = p - - dtγ_ref[] = dtγ - - # If we let ᶠw_data = ᶠw.components.data.:1 and ᶠw_unit = one.(ᶠw), then - # ᶠw == ᶠw_data .* ᶠw_unit. The Jacobian blocks involve ᶠw_data, not ᶠw. - ᶠw_data = ᶠw.components.data.:1 - - # If ∂(ᶜarg)/∂(ᶠw_data) = 0, then - # ∂(ᶠupwind_product(ᶠw, ᶜarg))/∂(ᶠw_data) = - # ᶠupwind_product(ᶠw + εw, arg) / to_scalar(ᶠw + εw). - # The εw is only necessary in case w = 0. - εw = Ref(Geometry.Covariant3Vector(eps(FT))) - to_scalar(vector) = vector.u₃ - - # ᶜinterp(ᶠw) = - # ᶜinterp(ᶠw)_data * ᶜinterp(ᶠw)_unit = - # ᶜinterp(ᶠw_data) * ᶜinterp(ᶠw)_unit - # norm_sqr(ᶜinterp(ᶠw)) = - # norm_sqr(ᶜinterp(ᶠw_data) * ᶜinterp(ᶠw)_unit) = - # ᶜinterp(ᶠw_data)^2 * norm_sqr(ᶜinterp(ᶠw)_unit) + (; ᶜK, ᶜΦ, ᶜp, ∂ᶜK∂ᶠw) = p + (; ᶠupwind_product, ᶠupwind_product_matrix, ᶠno_flux_row) = p + + ᶜρ_name = @name(c.ρ) + ᶜ𝔼_name = if :ρθ in propertynames(Y.c) + @name(c.ρθ) + elseif :ρe in propertynames(Y.c) + @name(c.ρe) + elseif :ρe_int in propertynames(Y.c) + @name(c.ρe_int) + end + ᶠ𝕄_name = @name(f.w) + ∂ᶜρₜ∂ᶠ𝕄 = ∂Yₜ∂Y[ᶜρ_name, ᶠ𝕄_name] + ∂ᶜ𝔼ₜ∂ᶠ𝕄 = ∂Yₜ∂Y[ᶜ𝔼_name, ᶠ𝕄_name] + ∂ᶠ𝕄ₜ∂ᶜρ = ∂Yₜ∂Y[ᶠ𝕄_name, ᶜρ_name] + ∂ᶠ𝕄ₜ∂ᶜ𝔼 = ∂Yₜ∂Y[ᶠ𝕄_name, ᶜ𝔼_name] + ∂ᶠ𝕄ₜ∂ᶠ𝕄 = ∂Yₜ∂Y[ᶠ𝕄_name, ᶠ𝕄_name] + + ᶠgⁱʲ = Fields.local_geometry_field(ᶠw).gⁱʲ + g³³(gⁱʲ) = Geometry.AxisTensor( + (Geometry.Contravariant3Axis(), Geometry.Contravariant3Axis()), + Geometry.components(gⁱʲ)[end], + ) + + # If ∂(ᶜχ)/∂(ᶠw) = 0, then + # ∂(ᶠupwind_product(ᶠw, ᶜχ))/∂(ᶠw) = + # ∂(ᶠupwind_product(ᶠw, ᶜχ))/∂(CT3(ᶠw)) * ∂(CT3(ᶠw))/∂(ᶠw) = + # vec_data(ᶠupwind_product(ᶠw + εw, ᶜχ)) / vec_data(CT3(ᶠw + εw)) * ᶠg³³ + # The ε is only necessary when w = 0. Since ᶠupwind_product is undefined at + # the boundaries, we also need to wrap it in a call to ᶠno_flux. + vec_data(vector) = vector[1] + εw = (C3(eps(FT)),) + # ᶜK = # norm_sqr(C123(ᶜuₕ) + C123(ᶜinterp(ᶠw))) / 2 = - # norm_sqr(ᶜuₕ) / 2 + norm_sqr(ᶜinterp(ᶠw)) / 2 = - # norm_sqr(ᶜuₕ) / 2 + ᶜinterp(ᶠw_data)^2 * norm_sqr(ᶜinterp(ᶠw)_unit) / 2 - # ∂(ᶜK)/∂(ᶠw_data) = - # ∂(ᶜK)/∂(ᶜinterp(ᶠw_data)) * ∂(ᶜinterp(ᶠw_data))/∂(ᶠw_data) = - # ᶜinterp(ᶠw_data) * norm_sqr(ᶜinterp(ᶠw)_unit) * ᶜinterp_stencil(1) - @. ∂ᶜK∂ᶠw_data = - ᶜinterp(ᶠw_data) * - norm_sqr(one(ᶜinterp(ᶠw))) * - ᶜinterp_stencil(one(ᶠw_data)) + # ACT12(ᶜuₕ) * ᶜuₕ / 2 + ACT3(ᶜinterp(ᶠw)) * ᶜinterp(ᶠw) / 2 + # ∂(ᶜK)/∂(ᶠw) = ACT3(ᶜinterp(ᶠw)) * ᶜinterp_matrix() + @. ∂ᶜK∂ᶠw = DiagonalMatrixRow(adjoint(CT3(ᶜinterp(ᶠw)))) ⋅ ᶜinterp_matrix() # ᶜρₜ = -ᶜdivᵥ(ᶠinterp(ᶜρ) * ᶠw) - # ∂(ᶜρₜ)/∂(ᶠw_data) = -ᶜdivᵥ_stencil(ᶠinterp(ᶜρ) * ᶠw_unit) - @. ∂ᶜρₜ∂ᶠ𝕄 = -(ᶜdivᵥ_stencil(ᶠinterp(ᶜρ) * one(ᶠw))) + # ∂(ᶜρₜ)/∂(ᶠw) = -ᶜdivᵥ_matrix() * ᶠinterp(ᶜρ) * ᶠg³³ + @. ∂ᶜρₜ∂ᶠ𝕄 = -(ᶜdivᵥ_matrix()) ⋅ DiagonalMatrixRow(ᶠinterp(ᶜρ) * g³³(ᶠgⁱʲ)) if :ρθ in propertynames(Y.c) ᶜρθ = Y.c.ρθ @@ -324,72 +344,85 @@ function Wfact!(W, Y, p, dtγ, t) if isnothing(ᶠupwind_product) # ᶜρθₜ = -ᶜdivᵥ(ᶠinterp(ᶜρθ) * ᶠw) - # ∂(ᶜρθₜ)/∂(ᶠw_data) = -ᶜdivᵥ_stencil(ᶠinterp(ᶜρθ) * ᶠw_unit) - @. ∂ᶜ𝔼ₜ∂ᶠ𝕄 = -(ᶜdivᵥ_stencil(ᶠinterp(ᶜρθ) * one(ᶠw))) + # ∂(ᶜρθₜ)/∂(ᶠw) = -ᶜdivᵥ_matrix() * ᶠinterp(ᶜρθ) * ᶠg³³ + @. ∂ᶜ𝔼ₜ∂ᶠ𝕄 = + -(ᶜdivᵥ_matrix()) ⋅ DiagonalMatrixRow(ᶠinterp(ᶜρθ) * g³³(ᶠgⁱʲ)) else # ᶜρθₜ = -ᶜdivᵥ(ᶠinterp(ᶜρ) * ᶠupwind_product(ᶠw, ᶜρθ / ᶜρ)) - # ∂(ᶜρθₜ)/∂(ᶠw_data) = - # -ᶜdivᵥ_stencil( - # ᶠinterp(ᶜρ) * ∂(ᶠupwind_product(ᶠw, ᶜρθ / ᶜρ))/∂(ᶠw_data), - # ) - @. ∂ᶜ𝔼ₜ∂ᶠ𝕄 = -(ᶜdivᵥ_stencil( - ᶠinterp(ᶜρ) * ᶠupwind_product(ᶠw + εw, ᶜρθ / ᶜρ) / - to_scalar(ᶠw + εw), - )) + # ∂(ᶜρθₜ)/∂(ᶠw) = + # -ᶜdivᵥ_matrix() * ᶠinterp(ᶜρ) * + # ∂(ᶠupwind_product(ᶠw, ᶜρθ / ᶜρ))/∂(ᶠw) + @. ∂ᶜ𝔼ₜ∂ᶠ𝕄 = + -(ᶜdivᵥ_matrix()) ⋅ DiagonalMatrixRow( + ᶠinterp(ᶜρ) * + vec_data(ᶠno_flux(ᶠupwind_product(ᶠw + εw, ᶜρθ / ᶜρ))) / + vec_data(CT3(ᶠw + εw)) * g³³(ᶠgⁱʲ), + ) end elseif :ρe in propertynames(Y.c) ᶜρe = Y.c.ρe @. ᶜK = norm_sqr(C123(ᶜuₕ) + C123(ᶜinterp(ᶠw))) / 2 @. ᶜp = pressure_ρe(ᶜρe, ᶜK, ᶜΦ, ᶜρ) - if isnothing(ᶠupwind_product) - if flags.∂ᶜ𝔼ₜ∂ᶠ𝕄_mode == :exact + if flags.∂ᶜ𝔼ₜ∂ᶠ𝕄_mode == :exact + if isnothing(ᶠupwind_product) # ᶜρeₜ = -ᶜdivᵥ(ᶠinterp(ᶜρe + ᶜp) * ᶠw) - # ∂(ᶜρeₜ)/∂(ᶠw_data) = - # -ᶜdivᵥ_stencil(ᶠinterp(ᶜρe + ᶜp) * ᶠw_unit) - - # ᶜdivᵥ_stencil(ᶠw) * ∂(ᶠinterp(ᶜρe + ᶜp))/∂(ᶠw_data) - # ∂(ᶠinterp(ᶜρe + ᶜp))/∂(ᶠw_data) = - # ∂(ᶠinterp(ᶜρe + ᶜp))/∂(ᶜp) * ∂(ᶜp)/∂(ᶠw_data) - # ∂(ᶠinterp(ᶜρe + ᶜp))/∂(ᶜp) = ᶠinterp_stencil(1) - # ∂(ᶜp)/∂(ᶠw_data) = ∂(ᶜp)/∂(ᶜK) * ∂(ᶜK)/∂(ᶠw_data) + # ∂(ᶜρeₜ)/∂(ᶠw) = + # -ᶜdivᵥ_matrix() * ( + # ᶠinterp(ᶜρe + ᶜp) * ᶠg³³ + + # CT3(ᶠw) * ∂(ᶠinterp(ᶜρe + ᶜp))/∂(ᶠw) + # ) + # ∂(ᶠinterp(ᶜρe + ᶜp))/∂(ᶠw) = + # ∂(ᶠinterp(ᶜρe + ᶜp))/∂(ᶜp) * ∂(ᶜp)/∂(ᶠw) + # ∂(ᶠinterp(ᶜρe + ᶜp))/∂(ᶜp) = ᶠinterp_matrix() + # ∂(ᶜp)/∂(ᶠw) = ∂(ᶜp)/∂(ᶜK) * ∂(ᶜK)/∂(ᶠw) # ∂(ᶜp)/∂(ᶜK) = -ᶜρ * R_d / cv_d @. ∂ᶜ𝔼ₜ∂ᶠ𝕄 = - -(ᶜdivᵥ_stencil(ᶠinterp(ᶜρe + ᶜp) * one(ᶠw))) - compose( - ᶜdivᵥ_stencil(ᶠw), - compose( - ᶠinterp_stencil(one(ᶜp)), - -(ᶜρ * R_d / cv_d) * ∂ᶜK∂ᶠw_data, - ), + -(ᶜdivᵥ_matrix()) ⋅ ( + DiagonalMatrixRow(ᶠinterp(ᶜρe + ᶜp) * g³³(ᶠgⁱʲ)) + + DiagonalMatrixRow(CT3(ᶠw)) ⋅ ᶠinterp_matrix() ⋅ + DiagonalMatrixRow(-(ᶜρ * R_d / cv_d)) ⋅ ∂ᶜK∂ᶠw ) - elseif flags.∂ᶜ𝔼ₜ∂ᶠ𝕄_mode == :no_∂ᶜp∂ᶜK - # same as above, but we approximate ∂(ᶜp)/∂(ᶜK) = 0, so that - # ∂ᶜ𝔼ₜ∂ᶠ𝕄 has 3 diagonals instead of 5 - @. ∂ᶜ𝔼ₜ∂ᶠ𝕄 = -(ᶜdivᵥ_stencil(ᶠinterp(ᶜρe + ᶜp) * one(ᶠw))) else - error( - "∂ᶜ𝔼ₜ∂ᶠ𝕄_mode must be :exact or :no_∂ᶜp∂ᶜK when using ρe \ - without upwinding", - ) - end - else - # TODO: Add Operator2Stencil for UpwindBiasedProductC2F to ClimaCore - # to allow exact Jacobian calculation. - if flags.∂ᶜ𝔼ₜ∂ᶠ𝕄_mode == :no_∂ᶜp∂ᶜK # ᶜρeₜ = # -ᶜdivᵥ(ᶠinterp(ᶜρ) * ᶠupwind_product(ᶠw, (ᶜρe + ᶜp) / ᶜρ)) - # ∂(ᶜρeₜ)/∂(ᶠw_data) = - # -ᶜdivᵥ_stencil( - # ᶠinterp(ᶜρ) * - # ∂(ᶠupwind_product(ᶠw, (ᶜρe + ᶜp) / ᶜρ))/∂(ᶠw_data), - # ) - @. ∂ᶜ𝔼ₜ∂ᶠ𝕄 = -(ᶜdivᵥ_stencil( - ᶠinterp(ᶜρ) * ᶠupwind_product(ᶠw + εw, (ᶜρe + ᶜp) / ᶜρ) / - to_scalar(ᶠw + εw), - )) + # ∂(ᶜρeₜ)/∂(ᶠw) = + # -ᶜdivᵥ_matrix() * ᶠinterp(ᶜρ) * ( + # ∂(ᶠupwind_product(ᶠw, (ᶜρe + ᶜp) / ᶜρ))/∂(ᶠw) + + # ᶠupwind_product_matrix(ᶠw) * ∂((ᶜρe + ᶜp) / ᶜρ)/∂(ᶠw) + # ∂((ᶜρe + ᶜp) / ᶜρ)/∂(ᶠw) = 1 / ᶜρ * ∂(ᶜp)/∂(ᶠw) + # ∂(ᶜp)/∂(ᶠw) = ∂(ᶜp)/∂(ᶜK) * ∂(ᶜK)/∂(ᶠw) + # ∂(ᶜp)/∂(ᶜK) = -ᶜρ * R_d / cv_d + @. ∂ᶜ𝔼ₜ∂ᶠ𝕄 = + -(ᶜdivᵥ_matrix()) ⋅ DiagonalMatrixRow(ᶠinterp(ᶜρ)) ⋅ ( + DiagonalMatrixRow( + vec_data( + ᶠno_flux( + ᶠupwind_product(ᶠw + εw, (ᶜρe + ᶜp) / ᶜρ), + ), + ) / vec_data(CT3(ᶠw + εw)) * g³³(ᶠgⁱʲ), + ) + + ᶠno_flux_row(ᶠupwind_product_matrix(ᶠw)) ⋅ + (-R_d / cv_d * ∂ᶜK∂ᶠw) + ) + end + elseif flags.∂ᶜ𝔼ₜ∂ᶠ𝕄_mode == :no_∂ᶜp∂ᶜK + # same as above, but we approximate ∂(ᶜp)/∂(ᶜK) = 0, so that + # ∂ᶜ𝔼ₜ∂ᶠ𝕄 has 3 diagonals instead of 5 + if isnothing(ᶠupwind_product) + @. ∂ᶜ𝔼ₜ∂ᶠ𝕄 = + -(ᶜdivᵥ_matrix()) ⋅ + DiagonalMatrixRow(ᶠinterp(ᶜρe + ᶜp) * g³³(ᶠgⁱʲ)) else - error("∂ᶜ𝔼ₜ∂ᶠ𝕄_mode must be :no_∂ᶜp∂ᶜK when using ρe with \ - upwinding") + @. ∂ᶜ𝔼ₜ∂ᶠ𝕄 = + -(ᶜdivᵥ_matrix()) ⋅ DiagonalMatrixRow( + ᶠinterp(ᶜρ) * vec_data( + ᶠno_flux(ᶠupwind_product(ᶠw + εw, (ᶜρe + ᶜp) / ᶜρ)), + ) / vec_data(CT3(ᶠw + εw)) * g³³(ᶠgⁱʲ), + ) end + else + error("∂ᶜ𝔼ₜ∂ᶠ𝕄_mode must be :exact or :no_∂ᶜp∂ᶜK when using ρe") end elseif :ρe_int in propertynames(Y.c) ᶜρe_int = Y.c.ρe_int @@ -401,102 +434,72 @@ function Wfact!(W, Y, p, dtγ, t) if isnothing(ᶠupwind_product) # ᶜρe_intₜ = - # -( - # ᶜdivᵥ(ᶠinterp(ᶜρe_int + ᶜp) * ᶠw) - - # ᶜinterp(dot(ᶠgradᵥ(ᶜp), Geometry.Contravariant3Vector(ᶠw)) - # ) - # ∂(ᶜρe_intₜ)/∂(ᶠw_data) = - # -( - # ᶜdivᵥ_stencil(ᶠinterp(ᶜρe_int + ᶜp) * ᶠw_unit) - - # ᶜinterp_stencil(dot( - # ᶠgradᵥ(ᶜp), - # Geometry.Contravariant3Vector(ᶠw_unit), - # ),) - # ) - @. ∂ᶜ𝔼ₜ∂ᶠ𝕄 = -( - ᶜdivᵥ_stencil(ᶠinterp(ᶜρe_int + ᶜp) * one(ᶠw)) - - ᶜinterp_stencil( - dot(ᶠgradᵥ(ᶜp), Geometry.Contravariant3Vector(one(ᶠw))), - ) - ) + # -ᶜdivᵥ(ᶠinterp(ᶜρe_int + ᶜp) * ᶠw) + + # ᶜinterp(adjoint(ᶠgradᵥ(ᶜp)) * CT3(ᶠw)) + # ∂(ᶜρe_intₜ)/∂(ᶠw) = + # -ᶜdivᵥ_matrix() * ᶠinterp(ᶜρe_int + ᶜp) * ᶠg³³ + + # ᶜinterp_matrix() * adjoint(ᶠgradᵥ(ᶜp)) * ᶠg³³ + @. ∂ᶜ𝔼ₜ∂ᶠ𝕄 = + -(ᶜdivᵥ_matrix()) ⋅ + DiagonalMatrixRow(ᶠinterp(ᶜρe_int + ᶜp) * g³³(ᶠgⁱʲ)) + + ᶜinterp_matrix() ⋅ + DiagonalMatrixRow(adjoint(ᶠgradᵥ(ᶜp)) * g³³(ᶠgⁱʲ)) else # ᶜρe_intₜ = - # -( - # ᶜdivᵥ( - # ᶠinterp(ᶜρ) * ᶠupwind_product(ᶠw, (ᶜρe_int + ᶜp) / ᶜρ), - # ) - - # ᶜinterp(dot(ᶠgradᵥ(ᶜp), Geometry.Contravariant3Vector(ᶠw))) - # ) - # ∂(ᶜρe_intₜ)/∂(ᶠw_data) = - # -( - # ᶜdivᵥ_stencil( - # ᶠinterp(ᶜρ) * - # ∂(ᶠupwind_product(ᶠw, (ᶜρe_int + ᶜp) / ᶜρ))/∂(ᶠw_data), - # ) - - # ᶜinterp_stencil(dot( - # ᶠgradᵥ(ᶜp), - # Geometry.Contravariant3Vector(ᶠw_unit), - # ),) - # ) - @. ∂ᶜ𝔼ₜ∂ᶠ𝕄 = -( - ᶜdivᵥ_stencil( - ᶠinterp(ᶜρ) * - ᶠupwind_product(ᶠw + εw, (ᶜρe_int + ᶜp) / ᶜρ) / - to_scalar(ᶠw + εw), - ) - ᶜinterp_stencil( - dot(ᶠgradᵥ(ᶜp), Geometry.Contravariant3Vector(one(ᶠw))), - ) - ) + # -ᶜdivᵥ(ᶠinterp(ᶜρ) * ᶠupwind_product(ᶠw, (ᶜρe_int + ᶜp) / ᶜρ)) + + # ᶜinterp(adjoint(ᶠgradᵥ(ᶜp)) * CT3(ᶠw)) + # ∂(ᶜρe_intₜ)/∂(ᶠw) = + # -ᶜdivᵥ_matrix() * ᶠinterp(ᶜρ) * + # ∂(ᶠupwind_product(ᶠw, (ᶜρe_int + ᶜp) / ᶜρ))/∂(ᶠw) + + # ᶜinterp_matrix() * adjoint(ᶠgradᵥ(ᶜp)) * ᶠg³³ + @. ∂ᶜ𝔼ₜ∂ᶠ𝕄 = + -(ᶜdivᵥ_matrix()) ⋅ DiagonalMatrixRow( + ᶠinterp(ᶜρ) * vec_data( + ᶠno_flux(ᶠupwind_product(ᶠw + εw, (ᶜρe_int + ᶜp) / ᶜρ)), + ) / vec_data(CT3(ᶠw + εw)) * g³³(ᶠgⁱʲ), + ) + + ᶜinterp_matrix() ⋅ + DiagonalMatrixRow(adjoint(ᶠgradᵥ(ᶜp)) * g³³(ᶠgⁱʲ)) end end - # To convert ∂(ᶠwₜ)/∂(ᶜ𝔼) to ∂(ᶠw_data)ₜ/∂(ᶜ𝔼) and ∂(ᶠwₜ)/∂(ᶠw_data) to - # ∂(ᶠw_data)ₜ/∂(ᶠw_data), we must extract the third component of each - # vector-valued stencil coefficient. - to_scalar_coefs(vector_coefs) = - map(vector_coef -> vector_coef.u₃, vector_coefs) - - # TODO: If we end up using :gradΦ_shenanigans, optimize it to - # `cached_stencil / ᶠinterp(ᶜρ)`. - if flags.∂ᶠ𝕄ₜ∂ᶜρ_mode != :exact && flags.∂ᶠ𝕄ₜ∂ᶜρ_mode != :gradΦ_shenanigans - error("∂ᶠ𝕄ₜ∂ᶜρ_mode must be :exact or :gradΦ_shenanigans") + # TODO: As an optimization, we can rewrite ∂ᶠ𝕄ₜ∂ᶜ𝔼 as 1 / ᶠinterp(ᶜρ) * M, + # where M is a constant matrix field. When ∂ᶠ𝕄ₜ∂ᶜρ_mode is set to + # :hydrostatic_balance, we can also do the same for ∂ᶠ𝕄ₜ∂ᶜρ. + if flags.∂ᶠ𝕄ₜ∂ᶜρ_mode != :exact && + flags.∂ᶠ𝕄ₜ∂ᶜρ_mode != :hydrostatic_balance + error("∂ᶠ𝕄ₜ∂ᶜρ_mode must be :exact or :hydrostatic_balance") end if :ρθ in propertynames(Y.c) # ᶠwₜ = -ᶠgradᵥ(ᶜp) / ᶠinterp(ᶜρ) - ᶠgradᵥ(ᶜK + ᶜΦ) # ∂(ᶠwₜ)/∂(ᶜρθ) = ∂(ᶠwₜ)/∂(ᶠgradᵥ(ᶜp)) * ∂(ᶠgradᵥ(ᶜp))/∂(ᶜρθ) # ∂(ᶠwₜ)/∂(ᶠgradᵥ(ᶜp)) = -1 / ᶠinterp(ᶜρ) # ∂(ᶠgradᵥ(ᶜp))/∂(ᶜρθ) = - # ᶠgradᵥ_stencil(γ * R_d * (ᶜρθ * R_d / p_0)^(γ - 1)) - @. ∂ᶠ𝕄ₜ∂ᶜ𝔼 = to_scalar_coefs( - -1 / ᶠinterp(ᶜρ) * - ᶠgradᵥ_stencil(γ * R_d * (ᶜρθ * R_d / p_0)^(γ - 1)), - ) + # ᶠgradᵥ_matrix() * γ * R_d * (ᶜρθ * R_d / p_0)^(γ - 1) + @. ∂ᶠ𝕄ₜ∂ᶜ𝔼 = + -DiagonalMatrixRow(1 / ᶠinterp(ᶜρ)) ⋅ ᶠgradᵥ_matrix() ⋅ + DiagonalMatrixRow(γ * R_d * (ᶜρθ * R_d / p_0)^(γ - 1)) if flags.∂ᶠ𝕄ₜ∂ᶜρ_mode == :exact # ᶠwₜ = -ᶠgradᵥ(ᶜp) / ᶠinterp(ᶜρ) - ᶠgradᵥ(ᶜK + ᶜΦ) # ∂(ᶠwₜ)/∂(ᶜρ) = ∂(ᶠwₜ)/∂(ᶠinterp(ᶜρ)) * ∂(ᶠinterp(ᶜρ))/∂(ᶜρ) # ∂(ᶠwₜ)/∂(ᶠinterp(ᶜρ)) = ᶠgradᵥ(ᶜp) / ᶠinterp(ᶜρ)^2 - # ∂(ᶠinterp(ᶜρ))/∂(ᶜρ) = ᶠinterp_stencil(1) - @. ∂ᶠ𝕄ₜ∂ᶜρ = to_scalar_coefs( - ᶠgradᵥ(ᶜp) / ᶠinterp(ᶜρ)^2 * ᶠinterp_stencil(one(ᶜρ)), - ) - elseif flags.∂ᶠ𝕄ₜ∂ᶜρ_mode == :gradΦ_shenanigans - # ᶠwₜ = ( - # -ᶠgradᵥ(ᶜp) / ᶠinterp(ᶜρ′) - - # ᶠgradᵥ(ᶜΦ) / ᶠinterp(ᶜρ′) * ᶠinterp(ᶜρ) - # ), where ᶜρ′ = ᶜρ but we approximate ∂(ᶜρ′)/∂(ᶜρ) = 0 - @. ∂ᶠ𝕄ₜ∂ᶜρ = to_scalar_coefs( - -(ᶠgradᵥ(ᶜΦ)) / ᶠinterp(ᶜρ) * ᶠinterp_stencil(one(ᶜρ)), - ) + # ∂(ᶠinterp(ᶜρ))/∂(ᶜρ) = ᶠinterp_matrix() + @. ∂ᶠ𝕄ₜ∂ᶜρ = + DiagonalMatrixRow(ᶠgradᵥ(ᶜp) / ᶠinterp(ᶜρ)^2) ⋅ ᶠinterp_matrix() + elseif flags.∂ᶠ𝕄ₜ∂ᶜρ_mode == :hydrostatic_balance + # same as above, but we assume that ᶠgradᵥ(ᶜp) / ᶠinterp(ᶜρ) = + # -ᶠgradᵥ(ᶜΦ) + @. ∂ᶠ𝕄ₜ∂ᶜρ = + -DiagonalMatrixRow(ᶠgradᵥ(ᶜΦ) / ᶠinterp(ᶜρ)) ⋅ ᶠinterp_matrix() end elseif :ρe in propertynames(Y.c) # ᶠwₜ = -ᶠgradᵥ(ᶜp) / ᶠinterp(ᶜρ) - ᶠgradᵥ(ᶜK + ᶜΦ) # ∂(ᶠwₜ)/∂(ᶜρe) = ∂(ᶠwₜ)/∂(ᶠgradᵥ(ᶜp)) * ∂(ᶠgradᵥ(ᶜp))/∂(ᶜρe) # ∂(ᶠwₜ)/∂(ᶠgradᵥ(ᶜp)) = -1 / ᶠinterp(ᶜρ) - # ∂(ᶠgradᵥ(ᶜp))/∂(ᶜρe) = ᶠgradᵥ_stencil(R_d / cv_d) - @. ∂ᶠ𝕄ₜ∂ᶜ𝔼 = to_scalar_coefs( - -1 / ᶠinterp(ᶜρ) * ᶠgradᵥ_stencil(R_d / cv_d * one(ᶜρe)), - ) + # ∂(ᶠgradᵥ(ᶜp))/∂(ᶜρe) = ᶠgradᵥ_matrix() * R_d / cv_d + @. ∂ᶠ𝕄ₜ∂ᶜ𝔼 = + -DiagonalMatrixRow(1 / ᶠinterp(ᶜρ)) ⋅ (ᶠgradᵥ_matrix() * R_d / cv_d) if flags.∂ᶠ𝕄ₜ∂ᶜρ_mode == :exact # ᶠwₜ = -ᶠgradᵥ(ᶜp) / ᶠinterp(ᶜρ) - ᶠgradᵥ(ᶜK + ᶜΦ) @@ -505,34 +508,28 @@ function Wfact!(W, Y, p, dtγ, t) # ∂(ᶠwₜ)/∂(ᶠinterp(ᶜρ)) * ∂(ᶠinterp(ᶜρ))/∂(ᶜρ) # ∂(ᶠwₜ)/∂(ᶠgradᵥ(ᶜp)) = -1 / ᶠinterp(ᶜρ) # ∂(ᶠgradᵥ(ᶜp))/∂(ᶜρ) = - # ᶠgradᵥ_stencil(R_d * (-(ᶜK + ᶜΦ) / cv_d + T_tri)) + # ᶠgradᵥ_matrix() * R_d * (-(ᶜK + ᶜΦ) / cv_d + T_tri) # ∂(ᶠwₜ)/∂(ᶠinterp(ᶜρ)) = ᶠgradᵥ(ᶜp) / ᶠinterp(ᶜρ)^2 - # ∂(ᶠinterp(ᶜρ))/∂(ᶜρ) = ᶠinterp_stencil(1) - @. ∂ᶠ𝕄ₜ∂ᶜρ = to_scalar_coefs( - -1 / ᶠinterp(ᶜρ) * - ᶠgradᵥ_stencil(R_d * (-(ᶜK + ᶜΦ) / cv_d + T_tri)) + - ᶠgradᵥ(ᶜp) / ᶠinterp(ᶜρ)^2 * ᶠinterp_stencil(one(ᶜρ)), - ) - elseif flags.∂ᶠ𝕄ₜ∂ᶜρ_mode == :gradΦ_shenanigans - # ᶠwₜ = ( - # -ᶠgradᵥ(ᶜp′) / ᶠinterp(ᶜρ′) - - # ᶠgradᵥ(ᶜΦ) / ᶠinterp(ᶜρ′) * ᶠinterp(ᶜρ) - # ), where ᶜρ′ = ᶜρ but we approximate ∂ᶜρ′/∂ᶜρ = 0, and where - # ᶜp′ = ᶜp but with ᶜK = 0 - @. ∂ᶠ𝕄ₜ∂ᶜρ = to_scalar_coefs( - -1 / ᶠinterp(ᶜρ) * - ᶠgradᵥ_stencil(R_d * (-(ᶜΦ) / cv_d + T_tri)) - - ᶠgradᵥ(ᶜΦ) / ᶠinterp(ᶜρ) * ᶠinterp_stencil(one(ᶜρ)), - ) + # ∂(ᶠinterp(ᶜρ))/∂(ᶜρ) = ᶠinterp_matrix() + @. ∂ᶠ𝕄ₜ∂ᶜρ = + -DiagonalMatrixRow(1 / ᶠinterp(ᶜρ)) ⋅ ᶠgradᵥ_matrix() ⋅ + DiagonalMatrixRow(R_d * (-(ᶜK + ᶜΦ) / cv_d + T_tri)) + + DiagonalMatrixRow(ᶠgradᵥ(ᶜp) / ᶠinterp(ᶜρ)^2) ⋅ ᶠinterp_matrix() + elseif flags.∂ᶠ𝕄ₜ∂ᶜρ_mode == :hydrostatic_balance + # same as above, but we assume that ᶠgradᵥ(ᶜp) / ᶠinterp(ᶜρ) = + # -ᶠgradᵥ(ᶜΦ) and that ᶜK is negligible compared ot ᶜΦ + @. ∂ᶠ𝕄ₜ∂ᶜρ = + -DiagonalMatrixRow(1 / ᶠinterp(ᶜρ)) ⋅ ᶠgradᵥ_matrix() ⋅ + DiagonalMatrixRow(R_d * (-(ᶜΦ) / cv_d + T_tri)) - + DiagonalMatrixRow(ᶠgradᵥ(ᶜΦ) / ᶠinterp(ᶜρ)) ⋅ ᶠinterp_matrix() end elseif :ρe_int in propertynames(Y.c) # ᶠwₜ = -ᶠgradᵥ(ᶜp) / ᶠinterp(ᶜρ) - ᶠgradᵥ(ᶜK + ᶜΦ) # ∂(ᶠwₜ)/∂(ᶜρe_int) = ∂(ᶠwₜ)/∂(ᶠgradᵥ(ᶜp)) * ∂(ᶠgradᵥ(ᶜp))/∂(ᶜρe_int) # ∂(ᶠwₜ)/∂(ᶠgradᵥ(ᶜp)) = -1 / ᶠinterp(ᶜρ) - # ∂(ᶠgradᵥ(ᶜp))/∂(ᶜρe_int) = ᶠgradᵥ_stencil(R_d / cv_d) - @. ∂ᶠ𝕄ₜ∂ᶜ𝔼 = to_scalar_coefs( - -1 / ᶠinterp(ᶜρ) * ᶠgradᵥ_stencil(R_d / cv_d * one(ᶜρe_int)), - ) + # ∂(ᶠgradᵥ(ᶜp))/∂(ᶜρe_int) = ᶠgradᵥ_matrix() * R_d / cv_d + @. ∂ᶠ𝕄ₜ∂ᶜ𝔼 = + DiagonalMatrixRow(-1 / ᶠinterp(ᶜρ)) ⋅ (ᶠgradᵥ_matrix() * R_d / cv_d) if flags.∂ᶠ𝕄ₜ∂ᶜρ_mode == :exact # ᶠwₜ = -ᶠgradᵥ(ᶜp) / ᶠinterp(ᶜρ) - ᶠgradᵥ(ᶜK + ᶜΦ) @@ -540,89 +537,50 @@ function Wfact!(W, Y, p, dtγ, t) # ∂(ᶠwₜ)/∂(ᶠgradᵥ(ᶜp)) * ∂(ᶠgradᵥ(ᶜp))/∂(ᶜρ) + # ∂(ᶠwₜ)/∂(ᶠinterp(ᶜρ)) * ∂(ᶠinterp(ᶜρ))/∂(ᶜρ) # ∂(ᶠwₜ)/∂(ᶠgradᵥ(ᶜp)) = -1 / ᶠinterp(ᶜρ) - # ∂(ᶠgradᵥ(ᶜp))/∂(ᶜρ) = ᶠgradᵥ_stencil(R_d * T_tri) + # ∂(ᶠgradᵥ(ᶜp))/∂(ᶜρ) = ᶠgradᵥ_matrix() * R_d * T_tri # ∂(ᶠwₜ)/∂(ᶠinterp(ᶜρ)) = ᶠgradᵥ(ᶜp) / ᶠinterp(ᶜρ)^2 - # ∂(ᶠinterp(ᶜρ))/∂(ᶜρ) = ᶠinterp_stencil(1) - @. ∂ᶠ𝕄ₜ∂ᶜρ = to_scalar_coefs( - -1 / ᶠinterp(ᶜρ) * ᶠgradᵥ_stencil(R_d * T_tri * one(ᶜρe_int)) + - ᶠgradᵥ(ᶜp) / ᶠinterp(ᶜρ)^2 * ᶠinterp_stencil(one(ᶜρ)), - ) - elseif flags.∂ᶠ𝕄ₜ∂ᶜρ_mode == :gradΦ_shenanigans - # ᶠwₜ = ( - # -ᶠgradᵥ(ᶜp) / ᶠinterp(ᶜρ′) - - # ᶠgradᵥ(ᶜΦ) / ᶠinterp(ᶜρ′) * ᶠinterp(ᶜρ) - # ), where ᶜp′ = ᶜp but we approximate ∂ᶜρ′/∂ᶜρ = 0 - @. ∂ᶠ𝕄ₜ∂ᶜρ = to_scalar_coefs( - -1 / ᶠinterp(ᶜρ) * ᶠgradᵥ_stencil(R_d * T_tri * one(ᶜρe_int)) - - ᶠgradᵥ(ᶜΦ) / ᶠinterp(ᶜρ) * ᶠinterp_stencil(one(ᶜρ)), - ) + # ∂(ᶠinterp(ᶜρ))/∂(ᶜρ) = ᶠinterp_matrix() + @. ∂ᶠ𝕄ₜ∂ᶜρ = + -DiagonalMatrixRow(1 / ᶠinterp(ᶜρ)) ⋅ + (ᶠgradᵥ_matrix() * R_d * T_tri) + + DiagonalMatrixRow(ᶠgradᵥ(ᶜp) / ᶠinterp(ᶜρ)^2) ⋅ ᶠinterp_matrix() + elseif flags.∂ᶠ𝕄ₜ∂ᶜρ_mode == :hydrostatic_balance + # same as above, but we assume that ᶠgradᵥ(ᶜp) / ᶠinterp(ᶜρ) = + # -ᶠgradᵥ(ᶜΦ) + @. ∂ᶠ𝕄ₜ∂ᶜρ = + DiagonalMatrixRow(-1 / ᶠinterp(ᶜρ)) ⋅ + (ᶠgradᵥ_matrix() * R_d * T_tri) - + DiagonalMatrixRow(ᶠgradᵥ(ᶜΦ) / ᶠinterp(ᶜρ)) ⋅ ᶠinterp_matrix() end end # ᶠwₜ = -ᶠgradᵥ(ᶜp) / ᶠinterp(ᶜρ) - ᶠgradᵥ(ᶜK + ᶜΦ) - # ∂(ᶠwₜ)/∂(ᶠw_data) = - # ∂(ᶠwₜ)/∂(ᶠgradᵥ(ᶜp)) * ∂(ᶠgradᵥ(ᶜp))/∂(ᶠw_dataₜ) + - # ∂(ᶠwₜ)/∂(ᶠgradᵥ(ᶜK + ᶜΦ)) * ∂(ᶠgradᵥ(ᶜK + ᶜΦ))/∂(ᶠw_dataₜ) = + # ∂(ᶠwₜ)/∂(ᶠw) = + # ∂(ᶠwₜ)/∂(ᶠgradᵥ(ᶜp)) * ∂(ᶠgradᵥ(ᶜp))/∂(ᶠw) + + # ∂(ᶠwₜ)/∂(ᶠgradᵥ(ᶜK + ᶜΦ)) * ∂(ᶠgradᵥ(ᶜK + ᶜΦ))/∂(ᶠw) = # ( # ∂(ᶠwₜ)/∂(ᶠgradᵥ(ᶜp)) * ∂(ᶠgradᵥ(ᶜp))/∂(ᶜK) + # ∂(ᶠwₜ)/∂(ᶠgradᵥ(ᶜK + ᶜΦ)) * ∂(ᶠgradᵥ(ᶜK + ᶜΦ))/∂(ᶜK) - # ) * ∂(ᶜK)/∂(ᶠw_dataₜ) + # ) * ∂(ᶜK)/∂(ᶠw) # ∂(ᶠwₜ)/∂(ᶠgradᵥ(ᶜp)) = -1 / ᶠinterp(ᶜρ) # ∂(ᶠgradᵥ(ᶜp))/∂(ᶜK) = - # ᶜ𝔼_name == :ρe ? ᶠgradᵥ_stencil(-ᶜρ * R_d / cv_d) : 0 + # ᶜ𝔼_name == :ρe ? ᶠgradᵥ_matrix() * (-ᶜρ * R_d / cv_d) : 0 # ∂(ᶠwₜ)/∂(ᶠgradᵥ(ᶜK + ᶜΦ)) = -1 - # ∂(ᶠgradᵥ(ᶜK + ᶜΦ))/∂(ᶜK) = ᶠgradᵥ_stencil(1) - # ∂(ᶜK)/∂(ᶠw_data) = - # ᶜinterp(ᶠw_data) * norm_sqr(ᶜinterp(ᶠw)_unit) * ᶜinterp_stencil(1) + # ∂(ᶠgradᵥ(ᶜK + ᶜΦ))/∂(ᶜK) = ᶠgradᵥ_matrix() if :ρθ in propertynames(Y.c) || :ρe_int in propertynames(Y.c) - @. ∂ᶠ𝕄ₜ∂ᶠ𝕄 = - to_scalar_coefs(compose(-1 * ᶠgradᵥ_stencil(one(ᶜK)), ∂ᶜK∂ᶠw_data)) + @. ∂ᶠ𝕄ₜ∂ᶠ𝕄 = -(ᶠgradᵥ_matrix()) ⋅ ∂ᶜK∂ᶠw elseif :ρe in propertynames(Y.c) - @. ∂ᶠ𝕄ₜ∂ᶠ𝕄 = to_scalar_coefs( - compose( - -1 / ᶠinterp(ᶜρ) * ᶠgradᵥ_stencil(-(ᶜρ * R_d / cv_d)) + - -1 * ᶠgradᵥ_stencil(one(ᶜK)), - ∂ᶜK∂ᶠw_data, - ), - ) + @. ∂ᶠ𝕄ₜ∂ᶠ𝕄 = + -( + DiagonalMatrixRow(1 / ᶠinterp(ᶜρ)) ⋅ ᶠgradᵥ_matrix() ⋅ + DiagonalMatrixRow(-(ᶜρ * R_d / cv_d)) + ᶠgradᵥ_matrix() + ) ⋅ ∂ᶜK∂ᶠw end - if W.test - # Checking every column takes too long, so just check one. - i, j, h = 1, 1, 1 - if :ρθ in propertynames(Y.c) - ᶜ𝔼_name = :ρθ - elseif :ρe in propertynames(Y.c) - ᶜ𝔼_name = :ρe - elseif :ρe_int in propertynames(Y.c) - ᶜ𝔼_name = :ρe_int - end - args = (implicit_tendency!, Y, p, t, i, j, h) - @assert matrix_column(∂ᶜρₜ∂ᶠ𝕄, axes(Y.f), i, j, h) == - exact_column_jacobian_block(args..., (:c, :ρ), (:f, :w)) - @assert matrix_column(∂ᶠ𝕄ₜ∂ᶜ𝔼, axes(Y.c), i, j, h) ≈ - exact_column_jacobian_block(args..., (:f, :w), (:c, ᶜ𝔼_name)) - @assert matrix_column(∂ᶠ𝕄ₜ∂ᶠ𝕄, axes(Y.f), i, j, h) ≈ - exact_column_jacobian_block(args..., (:f, :w), (:f, :w)) - ∂ᶜ𝔼ₜ∂ᶠ𝕄_approx = matrix_column(∂ᶜ𝔼ₜ∂ᶠ𝕄, axes(Y.f), i, j, h) - ∂ᶜ𝔼ₜ∂ᶠ𝕄_exact = - exact_column_jacobian_block(args..., (:c, ᶜ𝔼_name), (:f, :w)) - if flags.∂ᶜ𝔼ₜ∂ᶠ𝕄_mode == :exact - @assert ∂ᶜ𝔼ₜ∂ᶠ𝕄_approx ≈ ∂ᶜ𝔼ₜ∂ᶠ𝕄_exact - else - err = norm(∂ᶜ𝔼ₜ∂ᶠ𝕄_approx .- ∂ᶜ𝔼ₜ∂ᶠ𝕄_exact) / norm(∂ᶜ𝔼ₜ∂ᶠ𝕄_exact) - @assert err < 1e-6 - # Note: the highest value seen so far is ~3e-7 (only applies to ρe) - end - ∂ᶠ𝕄ₜ∂ᶜρ_approx = matrix_column(∂ᶠ𝕄ₜ∂ᶜρ, axes(Y.c), i, j, h) - ∂ᶠ𝕄ₜ∂ᶜρ_exact = exact_column_jacobian_block(args..., (:f, :w), (:c, :ρ)) - if flags.∂ᶠ𝕄ₜ∂ᶜρ_mode == :exact - @assert ∂ᶠ𝕄ₜ∂ᶜρ_approx ≈ ∂ᶠ𝕄ₜ∂ᶜρ_exact - else - err = norm(∂ᶠ𝕄ₜ∂ᶜρ_approx .- ∂ᶠ𝕄ₜ∂ᶜρ_exact) / norm(∂ᶠ𝕄ₜ∂ᶜρ_exact) - @assert err < 0.03 - # Note: the highest value seen so far for ρe is ~0.01, and the - # highest value seen so far for ρθ is ~0.02 - end + I = one(∂R∂Y) + if transform + @. ∂R∂Y = I / FT(dtγ) - ∂Yₜ∂Y + else + @. ∂R∂Y = FT(dtγ) * ∂Yₜ∂Y - I end end diff --git a/examples/implicit_solver_debugging_tools.jl b/examples/implicit_solver_debugging_tools.jl deleted file mode 100644 index 2f33b80ece..0000000000 --- a/examples/implicit_solver_debugging_tools.jl +++ /dev/null @@ -1,56 +0,0 @@ -using ForwardDiff: Dual -using SparseArrays: spdiagm - -using ClimaCore: Spaces, Operators - -get_var(obj, ::Tuple{}) = obj -get_var(obj, tup::Tuple) = get_var(getproperty(obj, tup[1]), Base.tail(tup)) -function exact_column_jacobian_block( - implicit_tendency!, - Y, - p, - t, - i, - j, - h, - Yₜ_name, - Y_name, -) - T = eltype(Y) - Y_var = get_var(Y, Y_name) - Y_var_vert_space = Spaces.column(axes(Y_var), i, j, h) - bot_level = Operators.left_idx(Y_var_vert_space) - top_level = Operators.right_idx(Y_var_vert_space) - partials = ntuple(_ -> zero(T), top_level - bot_level + 1) - Yᴰ = Dual.(Y, partials...) - Yᴰ_var = get_var(Yᴰ, Y_name) - ith_ε(i) = Dual.(zero(T), Base.setindex(partials, one(T), i)...) - set_level_εs!(level) = - parent(Spaces.level(Yᴰ_var, level)) .+= ith_ε(level - bot_level + 1) - foreach(set_level_εs!, bot_level:top_level) - Yₜᴰ = similar(Yᴰ) - implicit_tendency!(Yₜᴰ, Yᴰ, p, t) - col = Spaces.column(get_var(Yₜᴰ, Yₜ_name), i, j, h) - return vcat(map(dual -> [dual.partials.values...]', parent(col))...) -end - -# Note: These only work for scalar stencils. -vector_column(arg, i, j, h) = parent(Spaces.column(arg, i, j, h)) -function matrix_column(stencil, stencil_input_space, i, j, h) - lbw, ubw = Operators.bandwidths(eltype(stencil)) - coefs_column = Spaces.column(stencil, i, j, h).coefs - row_space = axes(coefs_column) - lrow = Operators.left_idx(row_space) - rrow = Operators.right_idx(row_space) - num_rows = rrow - lrow + 1 - col_space = Spaces.column(stencil_input_space, i, j, h) - lcol = Operators.left_idx(col_space) - rcol = Operators.right_idx(col_space) - num_cols = rcol - lcol + 1 - diag_key_value(diag) = - (diag + lrow - lcol) => view( - parent(getproperty(coefs_column, diag - lbw + 1)), - (max(lrow, lcol - diag):min(rrow, rcol - diag)) .- (lrow - 1), - ) - return spdiagm(num_rows, num_cols, map(diag_key_value, lbw:ubw)...) -end diff --git a/src/Fields/broadcast.jl b/src/Fields/broadcast.jl index 5a304ec098..b9d2cf9dd5 100644 --- a/src/Fields/broadcast.jl +++ b/src/Fields/broadcast.jl @@ -331,6 +331,9 @@ Base.Broadcast.broadcasted(fs::AbstractFieldStyle, ::typeof(/), args...) = Base.Broadcast.broadcasted(fs::AbstractFieldStyle, ::typeof(muladd), args...) = Base.Broadcast.broadcasted(fs, RecursiveApply.rmuladd, args...) +Base.Broadcast.broadcasted(fs::AbstractFieldStyle, ::typeof(zero), arg) = + Base.Broadcast.broadcasted(fs, RecursiveApply.rzero, arg) + # Specialize handling of vector-based functions to automatically add LocalGeometry information function Base.Broadcast.broadcasted( fs::AbstractFieldStyle, diff --git a/src/MatrixFields/band_matrix_row.jl b/src/MatrixFields/band_matrix_row.jl index 2b2355f16e..0b795fc091 100644 --- a/src/MatrixFields/band_matrix_row.jl +++ b/src/MatrixFields/band_matrix_row.jl @@ -55,9 +55,6 @@ band_matrix_row_type(ld, ud, T) = BandMatrixRow{ld, ud - ld + 1, T} Base.eltype(::Type{BandMatrixRow{ld, bw, T}}) where {ld, bw, T} = T -Base.zero(::Type{BandMatrixRow{ld, bw, T}}) where {ld, bw, T} = - BandMatrixRow{ld}(ntuple(_ -> rzero(T), Val(bw))...) - Base.map(f::F, rows::BandMatrixRow...) where {F} = BandMatrixRow{lower_diagonal(rows)}( map(f, map(row -> row.entries, rows)...)..., @@ -143,8 +140,20 @@ Base.:*(value::Geometry.SingleValue, row::BandMatrixRow) = Base.:/(row::BandMatrixRow, value::Number) = map(entry -> rdiv(entry, value), row) -inv(row::DiagonalMatrixRow) = DiagonalMatrixRow(inv(row[0])) +Base.zero(row::BandMatrixRow) = zero(typeof(row)) +Base.zero(::Type{BandMatrixRow{ld, bw, T}}) where {ld, bw, T} = + BandMatrixRow{ld}(ntuple(_ -> rzero(T), Val(bw))...) + +Base.one(row::BandMatrixRow) = one(typeof(row)) +Base.one(::Type{DiagonalMatrixRow{T}}) where {T} = + DiagonalMatrixRow(rmap(one, T)) +Base.one(::Type{BandMatrixRow{ld, bw, T}}) where {ld, bw, T} = + ld isa PlusHalf ? + error("A non-square matrix does not have a corresponding identity matrix") : + one(DiagonalMatrixRow{T}) + +inv(row::DiagonalMatrixRow) = DiagonalMatrixRow(rmap(inv, row[0])) inv(::BandMatrixRow{ld, bw}) where {ld, bw} = error( - "The inverse of a matrix with $bw diagonals is (usually) a dense matrix, \ + "The inverse of a matrix with $bw diagonals is generally a dense matrix, \ so it cannot be represented using BandMatrixRows", ) diff --git a/src/MatrixFields/field_matrix_solver.jl b/src/MatrixFields/field_matrix_solver.jl index 0210080b7f..f91100ced5 100644 --- a/src/MatrixFields/field_matrix_solver.jl +++ b/src/MatrixFields/field_matrix_solver.jl @@ -253,7 +253,7 @@ function check_field_matrix_solver(::BlockDiagonalSolve, _, A, b) end cheap_inv(_) = false -cheap_inv(::UniformScaling) = true +cheap_inv(::ScalingFieldMatrixEntry) = true cheap_inv(A::ColumnwiseBandMatrixField) = eltype(A) <: DiagonalMatrixRow NVTX.@annotate function run_field_matrix_solver!( @@ -271,7 +271,7 @@ NVTX.@annotate function run_field_matrix_solver!( # these kernels into one. However, `multiple_field_solve!` # launches threads horizontally, and loops vertically (which # is slow) to perform the solve. In some circumstances, - # when a vertical loop is not needed (e.g., UniformScaling) + # when a vertical loop is not needed (e.g., ScalingFieldMatrixEntry) # launching several kernels may be cheaper than launching one # slower kernel, so we first check for types that may lead to fast # kernels. diff --git a/src/MatrixFields/field_matrix_with_solver.jl b/src/MatrixFields/field_matrix_with_solver.jl index ffcddb29ae..b5284ab05b 100644 --- a/src/MatrixFields/field_matrix_with_solver.jl +++ b/src/MatrixFields/field_matrix_with_solver.jl @@ -20,9 +20,6 @@ FieldMatrixWithSolver( alg::FieldMatrixSolverAlgorithm = BlockDiagonalSolve(), ) = FieldMatrixWithSolver(A, FieldMatrixSolver(alg, A, b)) -# TODO: Find a simple way to make b an optional argument and add a method for -# Base.one(::FieldMatrixWithSolver). - Base.keys(A::FieldMatrixWithSolver) = keys(A.matrix) Base.values(A::FieldMatrixWithSolver) = values(A.matrix) @@ -41,10 +38,24 @@ Base.:(==)(A1::FieldMatrixWithSolver, A2::FieldMatrixWithSolver) = Base.similar(A::FieldMatrixWithSolver) = FieldMatrixWithSolver(similar(A.matrix), A.solver) +# Since zero(::FieldMatrix) retains the sparsity pattern of the original matrix +# while zeroing out all mutable entries, its linear solver is unchanged. Base.zero(A::FieldMatrixWithSolver) = FieldMatrixWithSolver(zero(A.matrix), A.solver) +# Since one(::FieldMatrix) is an identity matrix, it does not require a linear +# solver. The equation I * x == b can be solved directly, without calling ldiv. +# TODO: Find a simple way to construct a linear solver for the identity matrix. +Base.one(A::FieldMatrixWithSolver) = + FieldMatrixWithSolver(one(A.matrix), nothing) + +Base.Broadcast.broadcastable(A::FieldMatrixWithSolver) = A.matrix + +Base.Broadcast.materialize!(A::FieldMatrixWithSolver, matrix::FieldMatrix) = + Base.Broadcast.materialize!(A.matrix, matrix) + ldiv!(x::Fields.FieldVector, A::FieldMatrixWithSolver, b::Fields.FieldVector) = + isnothing(A.solver) ? error("FieldMatrixSolver is unavailable") : field_matrix_solve!(A.solver, x, A.matrix, b) mul!(b::Fields.FieldVector, A::FieldMatrixWithSolver, x::Fields.FieldVector) = diff --git a/src/MatrixFields/field_name.jl b/src/MatrixFields/field_name.jl index fa0cfbb15d..d961f375a4 100644 --- a/src/MatrixFields/field_name.jl +++ b/src/MatrixFields/field_name.jl @@ -101,6 +101,18 @@ wrapped_prop_names(::Val{prop_names}) where {prop_names} = ( wrapped_prop_names(Val(Base.tail(prop_names)))..., ) +filtered_names(f::F, x) where {F} = filtered_child_names(f, x, @name()) +function filtered_child_names(f::F, x, name) where {F} + field = get_field(x, name) + f(field) && return (name,) + internal_names = top_level_names(field) + isempty(internal_names) && return () + tuples_of_child_names = unrolled_map(internal_names) do internal_name + filtered_child_names(f, x, append_internal_name(name, internal_name)) + end + return unrolled_flatten(tuples_of_child_names) +end + ################################################################################ """ @@ -174,6 +186,9 @@ if hasfield(Method, :recursion_relation) for m in methods(wrapped_prop_names) m.recursion_relation = dont_limit end + for m in methods(filtered_child_names) + m.recursion_relation = dont_limit + end for m in methods(subtree_at_name) m.recursion_relation = dont_limit end diff --git a/src/MatrixFields/field_name_dict.jl b/src/MatrixFields/field_name_dict.jl index 272a8ee24f..f5ad66642b 100644 --- a/src/MatrixFields/field_name_dict.jl +++ b/src/MatrixFields/field_name_dict.jl @@ -59,12 +59,28 @@ end const FieldVectorView = FieldNameDict{FieldName} const FieldMatrix = FieldNameDict{FieldNamePair} +const ScalingFieldMatrixEntry{T} = + Union{UniformScaling{T}, DiagonalMatrixRow{T}} + +scaling_value(entry::UniformScaling) = entry.λ +scaling_value(entry::DiagonalMatrixRow) = entry[0] + check_entry(_, _) = false check_entry(::Type{FieldName}, ::Fields.Field) = true -check_entry(::Type{FieldNamePair}, ::UniformScaling) = true +check_entry(::Type{FieldNamePair}, ::ScalingFieldMatrixEntry) = true check_entry(::Type{FieldNamePair}, ::ColumnwiseBandMatrixField) = true -check_entry(_, entry::Base.AbstractBroadcasted) = - Base.Broadcast.BroadcastStyle(typeof(entry)) isa Fields.AbstractFieldStyle + +is_field_broadcasted(bc) = + Base.Broadcast.BroadcastStyle(typeof(bc)) isa Fields.AbstractFieldStyle +check_entry(::Type{FieldName}, entry::Base.AbstractBroadcasted) = + is_field_broadcasted(entry) +check_entry(::Type{FieldNamePair}, entry::Base.AbstractBroadcasted) = + is_field_broadcasted(entry) # && eltype(entry) <: BandMatrixRow +# TODO: Adding the eltype check introduces JET failures to several FieldMatrix +# test cases in CI. We may to implement our own version of eltype to avoid this. + +is_diagonal_matrix_entry(::ScalingFieldMatrixEntry) = true +is_diagonal_matrix_entry(entry) = eltype(entry) <: DiagonalMatrixRow function Base.show(io::IO, dict::FieldNameDict) T = eltype(keys(dict)) @@ -74,13 +90,13 @@ function Base.show(io::IO, dict::FieldNameDict) if entry isa Fields.Field print(io, eltype(entry), "-valued Field:") Fields._show_compact_field(io, entry, " ", true) - elseif entry isa UniformScaling - if entry.λ == 1 + elseif entry isa ScalingFieldMatrixEntry + if scaling_value(entry) == 1 print(io, "I") - elseif entry.λ == -1 + elseif scaling_value(entry) == -1 print(io, "-I") else - print(io, "$(entry.λ) * I") + print(io, "$(scaling_value(entry)) * I") end else print(io, entry) @@ -122,7 +138,8 @@ function Base.getindex(dict::FieldNameDict, key) key in keys(dict) || throw(KeyError(key)) key′, entry′ = unrolled_findonly(pair -> is_child_value(key, pair[1]), pairs(dict)) - return get_internal_entry(entry′, get_internal_key(key, key′)) + internal_key = get_internal_key(key, key′) + return get_internal_entry(entry′, internal_key, KeyError(key)) end get_internal_key(child_name::FieldName, name::FieldName) = @@ -132,20 +149,25 @@ get_internal_key(child_name_pair::FieldNamePair, name_pair::FieldNamePair) = ( extract_internal_name(child_name_pair[2], name_pair[2]), ) -unsupported_internal_entry_error(entry, key) = - error("Unsupported FieldNameDict operation: \ - get_internal_entry(<$(typeof(entry).name.name)>, $key)") - -get_internal_entry(entry, name::FieldName) = get_field(entry, name) -get_internal_entry(entry, name_pair::FieldNamePair) = - name_pair == (@name(), @name()) ? entry : - unsupported_internal_entry_error(entry, name_pair) -get_internal_entry(entry::UniformScaling, name_pair::FieldNamePair) = - name_pair[1] == name_pair[2] ? entry : - unsupported_internal_entry_error(entry, name_pair) +get_internal_entry(entry, name::FieldName, key_error) = get_field(entry, name) +get_internal_entry(entry, name_pair::FieldNamePair, key_error) = + name_pair == (@name(), @name()) ? entry : throw(key_error) +get_internal_entry( + entry::ScalingFieldMatrixEntry, + name_pair::FieldNamePair, + key_error, +) = + if name_pair[1] == name_pair[2] + entry + elseif is_overlapping_name(name_pair[1], name_pair[2]) + throw(key_error) + else + zero(entry) + end function get_internal_entry( entry::ColumnwiseBandMatrixField, name_pair::FieldNamePair, + key_error, ) # Ensure compatibility with RecursiveApply (i.e., with rmul). # See note above matrix_product_keys in field_name_set.jl for more details. @@ -165,7 +187,7 @@ function get_internal_entry( end end # Note: This assumes that the entry is in a FieldMatrixBroadcasted. else - unsupported_internal_entry_error(entry, name_pair) + throw(key_error) end end @@ -177,23 +199,42 @@ end function Base.similar(dict::FieldNameDict) entries = unrolled_map(values(dict)) do entry - entry isa UniformScaling ? entry : similar(entry) + entry isa ScalingFieldMatrixEntry ? entry : similar(entry) end return FieldNameDict(keys(dict), entries) end +# TODO: The behavior of this method is extremely counterintuitive---it is +# zeroing out mutable values, but leaving nonzero immutable values unchanged. +# We should probably use a different function name for this method. function Base.zero(dict::FieldNameDict) entries = unrolled_map(values(dict)) do entry - entry isa UniformScaling ? entry : zero(entry) + entry isa ScalingFieldMatrixEntry ? entry : zero(entry) end return FieldNameDict(keys(dict), entries) end -# Note: This assumes that the matrix has the same row and column units, since I -# cannot be multiplied by anything other than a scalar. function Base.one(matrix::FieldMatrix) - diagonal_keys = matrix_diagonal_keys(keys(matrix)) - return FieldNameDict(diagonal_keys, map(_ -> I, diagonal_keys)) + inferred_diagonal_keys = matrix_inferred_diagonal_keys(keys(matrix)) + entries = map(inferred_diagonal_keys) do key + if !(key in keys(matrix)) + I # default value for missing diagonal entries in a sparse matrix + else + # TODO: Add method for one(::Axis2Tensor) to simplify this. + T = + matrix[key] isa ScalingFieldMatrixEntry ? + eltype(matrix[key]) : eltype(eltype(matrix[key])) + if T <: Number + UniformScaling(one(T)) + elseif T <: Geometry.Axis2Tensor + tensor_data = UniformScaling(one(eltype(T))) + DiagonalMatrixRow(Geometry.AxisTensor(axes(T), tensor_data)) + else + error("Unsupported diagonal FieldMatrix entry type: $T") + end + end + end + return FieldNameDict(inferred_diagonal_keys, entries) end replace_name_tree(dict::FieldNameDict, name_tree) = @@ -210,7 +251,7 @@ end function check_diagonal_matrix(matrix, error_message_start = "The matrix") check_block_diagonal_matrix(matrix, error_message_start) non_diagonal_entry_pairs = unrolled_filter(pairs(matrix)) do pair - !(pair[2] isa UniformScaling || eltype(pair[2]) <: DiagonalMatrixRow) + !is_diagonal_matrix_entry(pair[2]) end non_diagonal_entry_keys = FieldMatrixKeys(unrolled_map(pair -> pair[1], non_diagonal_entry_pairs)) @@ -238,13 +279,45 @@ function lazy_main_diagonal(matrix) diagonal_keys = matrix_diagonal_keys(keys(matrix)) entries = map(diagonal_keys) do key entry = matrix[key] - entry isa UniformScaling || eltype(entry) <: DiagonalMatrixRow ? - entry : + is_diagonal_matrix_entry(entry) ? entry : Base.Broadcast.broadcasted(row -> DiagonalMatrixRow(row[0]), entry) end return FieldNameDict(diagonal_keys, entries) end +""" + identity_field_matrix(x) + +Constructs a `FieldMatrix` that represents the identity operator for the +`FieldVector` `x`. The keys of this `FieldMatrix` correspond to single values, +such as numbers and vectors. + +This offers an alternative to `one(matrix)`, which is not guaranteed to have all +the entries required to solve `matrix * x = b` for `x` if `matrix` is sparse. +""" +function identity_field_matrix(x::Fields.FieldVector) + single_field_names = filtered_names(x) do field + field isa Fields.Field && eltype(field) <: Geometry.SingleValue + end + single_field_keys = FieldVectorKeys(single_field_names, FieldNameTree(x)) + entries = map(single_field_keys) do name + # This must be consistent with the definition of one(::FieldMatrix). + T = eltype(get_field(x, name)) + if T <: Number + UniformScaling(one(T)) + elseif T <: Geometry.AxisVector + # TODO: Add methods for +(::UniformScaling, ::Axis2Tensor) and + # -(::UniformScaling, ::Axis2Tensor) to simplify this. + tensor_axes = (axes(T)[1], Geometry.dual(axes(T)[1])) + tensor_data = UniformScaling(one(eltype(T))) + DiagonalMatrixRow(Geometry.AxisTensor(tensor_axes, tensor_data)) + else + I # default value for elements that are neither scalars nor vectors + end + end + return FieldNameDict(corresponding_matrix_keys(single_field_keys), entries) +end + """ field_vector_view(x, [name_tree]) @@ -253,24 +326,11 @@ Constructs a `FieldVectorView` that contains all of the `Field`s in the be modified if needed. """ function field_vector_view(x, name_tree = FieldNameTree(x)) - keys_of_fields = FieldVectorKeys(names_of_fields(x, name_tree), name_tree) - entries = map(name -> get_field(x, name), keys_of_fields) - return FieldNameDict(keys_of_fields, entries) + field_names = filtered_names(field -> field isa Fields.Field, x) + field_keys = FieldVectorKeys(field_names, name_tree) + entries = map(name -> get_field(x, name), field_keys) + return FieldNameDict(field_keys, entries) end -names_of_fields(x, name_tree) = - unrolled_flatmap(top_level_names(x)) do name - entry = get_field(x, name) - if entry isa Fields.Field - (name,) - elseif entry isa Fields.FieldVector - unrolled_map(names_of_fields(entry, name_tree)) do internal_name - append_internal_name(name, internal_name) - end - else - error("field_vector_view does not support entries of type \ - $(typeof(entry).name.name)") - end - end """ concrete_field_vector(vector) @@ -302,9 +362,6 @@ concrete_field_vector_within_subtree(tree, vector) = # This is required for type-stability as of Julia 1.9. if hasfield(Method, :recursion_relation) dont_limit = (args...) -> true - for m in methods(names_of_fields) - m.recursion_relation = dont_limit - end for m in methods(concrete_field_vector_within_subtree) m.recursion_relation = dont_limit end @@ -319,11 +376,22 @@ const FieldVectorStyleType = Union{ Base.Broadcast.Broadcasted{<:Fields.FieldVectorStyle}, } +const SingleValueStyle = + Union{Base.Broadcast.DefaultArrayStyle{0}, Base.Broadcast.Style{Tuple}} + +const SingleValueStyleType = Union{ + Number, + Tuple{Geometry.SingleValue}, + Base.Broadcast.Broadcasted{<:SingleValueStyle}, +} + Base.Broadcast.broadcastable(vector_or_matrix::FieldNameDict) = vector_or_matrix Base.Broadcast.BroadcastStyle(::Type{<:FieldNameDict}) = FieldNameDictStyle() Base.Broadcast.BroadcastStyle(::FieldNameDictStyle, ::Fields.FieldVectorStyle) = FieldNameDictStyle() +Base.Broadcast.BroadcastStyle(::FieldNameDictStyle, ::SingleValueStyle) = + FieldNameDictStyle() function field_matrix_broadcast_error(f, args...) arg_string(::FieldVectorView) = "" @@ -365,14 +433,42 @@ Base.Broadcast.broadcasted( arg::FieldNameDict, ) = arg +# Add support for multiplication and division by single values. +function Base.Broadcast.broadcasted( + ::FieldNameDictStyle, + f::Union{typeof(*), typeof(/), typeof(\)}, + single_value_or_bc::SingleValueStyleType, + vector_or_matrix::FieldNameDict, +) + single_value = Base.Broadcast.materialize(single_value_or_bc) + entries = unrolled_map(values(vector_or_matrix)) do entry + entry isa ScalingFieldMatrixEntry ? f(single_value, entry) : + Base.Broadcast.broadcasted(f, single_value, entry) + end + return FieldNameDict(keys(vector_or_matrix), entries) +end +function Base.Broadcast.broadcasted( + ::FieldNameDictStyle, + f::Union{typeof(*), typeof(/), typeof(\)}, + vector_or_matrix::FieldNameDict, + single_value_or_bc::SingleValueStyleType, +) + single_value = Base.Broadcast.materialize(single_value_or_bc) + entries = unrolled_map(values(vector_or_matrix)) do entry + entry isa ScalingFieldMatrixEntry ? f(entry, single_value) : + Base.Broadcast.broadcasted(f, entry, single_value) + end + return FieldNameDict(keys(vector_or_matrix), entries) +end + function Base.Broadcast.broadcasted( ::FieldNameDictStyle, ::typeof(zero), vector_or_matrix::FieldNameDict, ) entries = unrolled_map(values(vector_or_matrix)) do entry - entry isa UniformScaling ? zero(entry) : - Base.Broadcast.broadcasted(value -> rzero(typeof(value)), entry) + entry isa ScalingFieldMatrixEntry ? zero(entry) : + Base.Broadcast.broadcasted(zero, entry) end return FieldNameDict(keys(vector_or_matrix), entries) end @@ -383,7 +479,8 @@ function Base.Broadcast.broadcasted( vector_or_matrix::FieldNameDict, ) entries = unrolled_map(values(vector_or_matrix)) do entry - entry isa UniformScaling ? -entry : Base.Broadcast.broadcasted(-, entry) + entry isa ScalingFieldMatrixEntry ? -entry : + Base.Broadcast.broadcasted(-, entry) end return FieldNameDict(keys(vector_or_matrix), entries) end @@ -399,11 +496,14 @@ function Base.Broadcast.broadcasted( if key in intersect(keys(vector_or_matrix1), keys(vector_or_matrix2)) entry1 = vector_or_matrix1[key] entry2 = vector_or_matrix2[key] - if entry1 isa UniformScaling && entry2 isa UniformScaling + if ( + entry1 isa ScalingFieldMatrixEntry && + entry2 isa ScalingFieldMatrixEntry + ) f(entry1, entry2) - elseif entry1 isa UniformScaling + elseif entry1 isa ScalingFieldMatrixEntry Base.Broadcast.broadcasted(f, (entry1,), entry2) - elseif entry2 isa UniformScaling + elseif entry2 isa ScalingFieldMatrixEntry Base.Broadcast.broadcasted(f, entry1, (entry2,)) else Base.Broadcast.broadcasted(f, entry1, entry2) @@ -415,7 +515,7 @@ function Base.Broadcast.broadcasted( vector_or_matrix2[key] else entry = vector_or_matrix2[key] - entry isa UniformScaling ? -entry : + entry isa ScalingFieldMatrixEntry ? -entry : Base.Broadcast.broadcasted(-, entry) end end @@ -440,12 +540,18 @@ function Base.Broadcast.broadcasted( key1, key2 = matrix_product_argument_keys(product_key, summand_name) entry1 = matrix[key1] entry2 = vector_or_matrix[key2] - if entry1 isa UniformScaling && entry2 isa UniformScaling - entry1 * entry2 - elseif entry1 isa UniformScaling - Base.Broadcast.broadcasted(*, entry1.λ, entry2) - elseif entry2 isa UniformScaling - Base.Broadcast.broadcasted(*, entry1, entry2.λ) + if ( + entry1 isa ScalingFieldMatrixEntry && + entry2 isa ScalingFieldMatrixEntry + ) + product_value = scaling_value(entry1) * scaling_value(entry2) + product_value isa Number ? + UniformScaling(product_value) : + DiagonalMatrixRow(product_value) + elseif entry1 isa ScalingFieldMatrixEntry + Base.Broadcast.broadcasted(*, (scaling_value(entry1),), entry2) + elseif entry2 isa ScalingFieldMatrixEntry + Base.Broadcast.broadcasted(*, entry1, (scaling_value(entry2),)) else Base.Broadcast.broadcasted(⋅, entry1, entry2) end @@ -471,7 +577,7 @@ function Base.Broadcast.broadcasted( "inv.() cannot be computed because the matrix", ) entries = unrolled_map(values(matrix)) do entry - entry isa UniformScaling ? inv(entry) : + entry isa ScalingFieldMatrixEntry ? inv(entry) : Base.Broadcast.broadcasted(inv, entry) end return FieldNameDict(keys(matrix), entries) @@ -536,9 +642,9 @@ NVTX.@annotate function copyto_foreach!( ) foreach(keys(vector_or_matrix)) do key entry = vector_or_matrix[key] - if dest[key] isa UniformScaling - dest[key] == entry || error("UniformScaling is immutable") - elseif entry isa UniformScaling + if dest[key] isa ScalingFieldMatrixEntry + dest[key] == entry || error("matrix entry at $key is immutable") + elseif entry isa ScalingFieldMatrixEntry dest[key] .= (entry,) else dest[key] .= entry diff --git a/src/MatrixFields/field_name_set.jl b/src/MatrixFields/field_name_set.jl index 39bc5bcc41..677d2e09b1 100644 --- a/src/MatrixFields/field_name_set.jl +++ b/src/MatrixFields/field_name_set.jl @@ -144,13 +144,16 @@ function cartesian_product(row_set::FieldVectorKeys, col_set::FieldVectorKeys) return FieldMatrixKeys(result_values, name_tree) end -function matrix_row_keys(set::FieldMatrixKeys) - result_values′ = unrolled_map(name_pair -> name_pair[1], set.values) +function corresponding_vector_keys(set::FieldMatrixKeys, ::Val{N}) where {N} + result_values′ = unrolled_map(name_pair -> name_pair[N], set.values) result_values = unique_and_non_overlapping_values(result_values′, set.name_tree) return FieldVectorKeys(result_values, set.name_tree) end +matrix_row_keys(set::FieldMatrixKeys) = corresponding_vector_keys(set, Val(1)) +matrix_col_keys(set::FieldMatrixKeys) = corresponding_vector_keys(set, Val(2)) + function matrix_off_diagonal_keys(set::FieldMatrixKeys) result_values = unrolled_filter(name_pair -> name_pair[1] != name_pair[2], set.values) @@ -173,6 +176,16 @@ function matrix_diagonal_keys(set::FieldMatrixKeys) return FieldMatrixKeys(result_values, set.name_tree) end +function matrix_inferred_diagonal_keys(set::FieldMatrixKeys) + row_keys = matrix_row_keys(set) + col_keys = matrix_col_keys(set) + diag_keys = matrix_row_keys(matrix_diagonal_keys(set)) + all_keys = + issubset(row_keys, diag_keys) && issubset(col_keys, diag_keys) ? + diag_keys : union(row_keys, col_keys) # only compute the union if needed + return corresponding_matrix_keys(all_keys) +end + #= There are four cases that we need to support in order to be compatible with RecursiveApply (i.e., with rmul): @@ -187,7 +200,7 @@ RecursiveApply (i.e., with rmul): (name, name) * (name_child, _) -> (name_child, name_child) * (name_child, _) We are able to support this by extracting internal diagonal blocks from FieldNameDict entries. We can only extract an internal diagonal block from a - LinearAlgebra.UniformScaling or a ColumnwiseBandMatrixField of SingleValues. + ScalingFieldMatrixEntry or a ColumnwiseBandMatrixField of SingleValues. 4. (name1, name1) * name2 -> (name_child, name_child) * name_child or (name1, name1) * (name2, _) -> (name_child, name_child) * (name_child, _) This is a combination of cases 2 and 3, where "name_child" is a child name of diff --git a/src/MatrixFields/single_field_solver.jl b/src/MatrixFields/single_field_solver.jl index 59d3e21e61..93ee1b23cc 100644 --- a/src/MatrixFields/single_field_solver.jl +++ b/src/MatrixFields/single_field_solver.jl @@ -11,14 +11,14 @@ inv_return_type(::Type{X}) where {T, X <: Geometry.Axis2TensorOrAdj{T}} = Tuple{dual_type(Geometry.axis2(X)), dual_type(Geometry.axis1(X))}, ) -x_eltype(A::UniformScaling, b) = x_eltype(eltype(A), eltype(b)) +x_eltype(A::ScalingFieldMatrixEntry, b) = x_eltype(eltype(A), eltype(b)) x_eltype(A::ColumnwiseBandMatrixField, b) = x_eltype(eltype(eltype(A)), eltype(b)) x_eltype(::Type{T_A}, ::Type{T_b}) where {T_A, T_b} = rmul_return_type(inv_return_type(T_A), T_b) # Base.promote_op(rmul_with_projection, inv_return_type(T_A), T_b, LG) -unit_eltype(A::UniformScaling) = eltype(A) +unit_eltype(A::ScalingFieldMatrixEntry) = eltype(A) unit_eltype(A::ColumnwiseBandMatrixField) = unit_eltype(eltype(eltype(A)), local_geometry_type(A)) unit_eltype(::Type{T_A}, ::Type{LG}) where {T_A, LG} = @@ -27,7 +27,7 @@ unit_eltype(::Type{T_A}, ::Type{LG}) where {T_A, LG} = ################################################################################ -check_single_field_solver(::UniformScaling, _) = nothing +check_single_field_solver(::ScalingFieldMatrixEntry, _) = nothing function check_single_field_solver(A, b) matrix_shape(A) == Square() || error( "Cannot solve linear system because a diagonal entry in A is not a \ @@ -39,7 +39,7 @@ function check_single_field_solver(A, b) ) end -single_field_solver_cache(::UniformScaling, b) = similar(b, Tuple{}) +single_field_solver_cache(::ScalingFieldMatrixEntry, b) = similar(b, Tuple{}) function single_field_solver_cache(A::ColumnwiseBandMatrixField, b) ud = outer_diagonals(eltype(A))[2] cache_eltype = @@ -60,7 +60,8 @@ function single_field_solve_diag_matrix_row!( (A₀,) = Aⱼs @. x_vals = inv(A₀) ⊠ b_vals end -single_field_solve!(_, x, A::UniformScaling, b) = x .= inv(A.λ) .* b +single_field_solve!(_, x, A::ScalingFieldMatrixEntry, b) = + x .= (inv(scaling_value(A)),) .* b function single_field_solve!(cache, x, A::ColumnwiseBandMatrixField, b) if eltype(A) <: MatrixFields.DiagonalMatrixRow single_field_solve_diag_matrix_row!(cache, x, A, b) @@ -112,8 +113,8 @@ function _single_field_solve_col!( Fields.field_values(b), vindex, ) - elseif A isa UniformScaling - x .= inv(A.λ) .* b + elseif A isa ScalingFieldMatrixEntry + x .= (inv(scaling_value(A)),) .* b else error("uncaught case") end diff --git a/test/MatrixFields/field_names.jl b/test/MatrixFields/field_names.jl index 305cf18a7b..e6f94e1464 100644 --- a/test/MatrixFields/field_names.jl +++ b/test/MatrixFields/field_names.jl @@ -1,4 +1,5 @@ import LinearAlgebra: I +import ClimaCore.RecursiveApply: rzero import ClimaCore.DataLayouts: replace_basetype import ClimaCore.MatrixFields: @name, is_subset_that_covers_set @@ -11,8 +12,9 @@ Base.propertynames(::Foo) = (:value,) Base.getproperty(foo::Foo, s::Symbol) = s == :value ? getfield(foo, :_value) : error("Invalid property name") Base.convert(::Type{Foo{T}}, foo::Foo) where {T} = Foo{T}(foo.value) +Base.zero(::Type{Foo{T}}) where {T} = Foo(zero(T)) -const x = (; foo = Foo(0), a = (; b = 1, c = ((; d = 2), (;), ((), nothing)))) +const x = (; foo = Foo(0), a = (; b = 1, c = ((; d = 2), (;), (3, ())))) @testset "FieldName Unit Tests" begin @test_all @name() == MatrixFields.FieldName() @@ -618,7 +620,7 @@ end end @testset "Other FieldNameSet Operations" begin - # With one exception, none of the following operations require a + # With three exceptions, none of the following operations require a # FieldNameTree. @test_all MatrixFields.corresponding_matrix_keys(drop_tree(v_set1)) == @@ -641,6 +643,8 @@ end @test_all MatrixFields.matrix_row_keys(drop_tree(m_set1)) == vector_keys_no_tree(@name(foo), @name(a.b)) + @test_all MatrixFields.matrix_col_keys(drop_tree(m_set1)) == + vector_keys_no_tree(@name(foo), @name(a.c)) @test_all MatrixFields.matrix_row_keys(m_set4) == vector_keys_no_tree( @name(foo.value), @@ -649,9 +653,19 @@ end @name(a.c.:(2)), @name(a.c.:(3)) ) + @test_all MatrixFields.matrix_col_keys(m_set4) == vector_keys_no_tree( + @name(foo.value), + @name(a.c.:(1)), + @name(a.c.:(2)), + @name(a.c.:(3)) + ) + @test_throws "FieldNameTree" MatrixFields.matrix_row_keys( drop_tree(m_set4), ) + @test_throws "FieldNameTree" MatrixFields.matrix_col_keys( + drop_tree(m_set4), + ) @test_all MatrixFields.matrix_off_diagonal_keys(drop_tree(m_set4)) == matrix_keys_no_tree( @@ -667,126 +681,231 @@ end (@name(a.c.:(1)), @name(a.c.:(1))), (@name(a.c.:(3)), @name(a.c.:(3))), ) + + @test_all MatrixFields.matrix_inferred_diagonal_keys( + drop_tree(m_set1), + ) == matrix_keys_no_tree( + (@name(foo), @name(foo)), + (@name(a.b), @name(a.b)), + (@name(a.c), @name(a.c)), + ) + + @test_all MatrixFields.matrix_inferred_diagonal_keys(m_set4) == + matrix_keys_no_tree( + (@name(foo.value), @name(foo.value)), + (@name(a.b), @name(a.b)), + (@name(a.c.:(1)), @name(a.c.:(1))), + (@name(a.c.:(2)), @name(a.c.:(2))), + (@name(a.c.:(3)), @name(a.c.:(3))), + ) + + @test_throws "FieldNameTree" MatrixFields.matrix_inferred_diagonal_keys( + drop_tree(m_set4), + ) end end @testset "FieldNameDict Unit Tests" begin FT = Float64 - center_space, face_space = test_spaces(FT) - x_FT = convert(replace_basetype(Int, FT, typeof(x)), x) + C3 = Geometry.Covariant3Vector{FT} + C12 = Geometry.Covariant12Vector{FT} + CT3 = Geometry.Contravariant3Vector{FT} + CT12 = Geometry.Contravariant12Vector{FT} + C12XC3 = typeof(zero(C12) * zero(C3)') + CT3XC3 = typeof(zero(CT3) * zero(C3)') + C12XCT12 = typeof(zero(C12) * zero(CT12)') + CT3XCT12 = typeof(zero(CT3) * zero(CT12)') + x_C12 = rzero(replace_basetype(Int, C12, typeof(x))) + x_CT3 = rzero(replace_basetype(Int, CT3, typeof(x))) + x_C12XC3 = rzero(replace_basetype(Int, C12XC3, typeof(x))) + x_CT3XCT12 = rzero(replace_basetype(Int, CT3XCT12, typeof(x))) + I_CT3XC3 = DiagonalMatrixRow(Geometry.AxisTensor(axes(CT3XC3), I)) + I_C12XCT12 = DiagonalMatrixRow(Geometry.AxisTensor(axes(C12XCT12), I)) + + center_space, face_space = test_spaces(FT) + seed!(1) # ensures reproducibility - vector = Fields.FieldVector(; + vector_of_scalars = Fields.FieldVector(; foo = random_field(typeof(x_FT.foo), center_space), a = random_field(typeof(x_FT.a), face_space), ) - matrix = MatrixFields.replace_name_tree( - MatrixFields.FieldMatrix( - (@name(foo), @name(foo)) => -I, - (@name(a), @name(a)) => - random_field(DiagonalMatrixRow{FT}, face_space), - (@name(foo), @name(a.b)) => random_field( - BidiagonalMatrixRow{typeof(x_FT.foo)}, - center_space, - ), - (@name(a), @name(foo._value)) => random_field( - QuaddiagonalMatrixRow{typeof(x_FT.a)}, - face_space, - ), + matrix_of_scalars = MatrixFields.FieldMatrix( + (@name(foo), @name(foo)) => + random_field(DiagonalMatrixRow{FT}, center_space), + (@name(foo), @name(a.b)) => random_field( + BidiagonalMatrixRow{typeof(x_FT.foo)}, + center_space, ), - MatrixFields.FieldNameTree(vector), - ) # Add a FieldNameTree in order to fully test the behavior of getindex. - - @test_all MatrixFields.field_vector_view(vector) == - MatrixFields.FieldVectorView( - @name(foo) => vector.foo, - @name(a) => vector.a, + (@name(a), @name(foo._value)) => + random_field(QuaddiagonalMatrixRow{typeof(x_FT.a)}, face_space), + (@name(a), @name(a)) => -I, ) - vector_view = MatrixFields.field_vector_view(vector) - - # Some of the `.*`s in the following RegEx strings are needed to account for - # module qualifiers that may or may not get printed, depending on how these - # tests are run. - - @test startswith( - string(vector_view), - r""" - .*FieldVectorView with 2 entries: - @name\(foo\) => .*-valued Field: - _value: \[.*\] - @name\(a\) => .*-valued Field: - """, + vector_of_vectors = Fields.FieldVector(; + foo = random_field(typeof(x_C12.foo), center_space), + a = random_field(typeof(x_CT3.a), face_space), ) - @test startswith( - string(matrix), - r""" - .*FieldMatrix with 4 entries: - \(@name\(foo\), @name\(foo\)\) => -I - \(@name\(a\), @name\(a\)\) => .*DiagonalMatrixRow{.*}-valued Field: - entries: \ - 1: \[.*\] - \(@name\(foo\), @name\(a.b\)\) => .*BidiagonalMatrixRow{.*}-valued Field: - entries: \ - 1: \ - _value: \[.*\] - 2: \ - _value: \[.*\] - \(@name\(a\), @name\(foo._value\)\) => .*QuaddiagonalMatrixRow{.*}-valued Field: - """, - ) broken = Sys.iswindows() - - @test_all vector_view[@name(foo)] == vector.foo - @test_throws KeyError vector_view[@name(invalid_name)] - @test_throws KeyError vector_view[@name(foo.invalid_name)] - - @test_all matrix[@name(foo), @name(foo)] == -I - @test_throws KeyError matrix[@name(invalid_name), @name(foo)] - @test_throws KeyError matrix[@name(foo.invalid_name), @name(foo)] - - @test_all vector_view[@name(foo._value)] == vector.foo._value - @test_all vector_view[@name(a.c)] == vector.a.c - - @test_all matrix[@name(foo._value), @name(foo._value)] == - matrix[@name(foo), @name(foo)] - @test_throws "get_internal_entry" matrix[@name(foo), @name(foo._value)] - @test_throws "get_internal_entry" matrix[@name(foo._value), @name(foo)] - - @test_all matrix[@name(a.c), @name(a.c)] == matrix[@name(a), @name(a)] - @test_throws "get_internal_entry" matrix[@name(a), @name(a.c)] - @test_throws "get_internal_entry" matrix[@name(a.c), @name(a)] - - @test_all matrix[@name(foo._value), @name(a.b)] isa Base.AbstractBroadcasted - @test Base.materialize(matrix[@name(foo._value), @name(a.b)]) == - map(row -> map(foo -> foo.value, row), matrix[@name(foo), @name(a.b)]) - - @test_all matrix[@name(a.c), @name(foo._value)] isa Base.AbstractBroadcasted - @test Base.materialize(matrix[@name(a.c), @name(foo._value)]) == - map(row -> map(a -> a.c, row), matrix[@name(a), @name(foo._value)]) - - vector_keys = MatrixFields.FieldVectorKeys((@name(foo), @name(a.c))) - @test_all vector_view[vector_keys] == MatrixFields.FieldVectorView( - @name(foo) => vector_view[@name(foo)], - @name(a.c) => vector_view[@name(a.c)], + matrix_of_tensors = MatrixFields.FieldMatrix( + (@name(foo), @name(foo)) => + random_field(DiagonalMatrixRow{C12XCT12}, center_space), + (@name(foo), @name(a.b)) => random_field( + BidiagonalMatrixRow{typeof(x_C12XC3.foo)}, + center_space, + ), + (@name(a), @name(foo._value)) => random_field( + QuaddiagonalMatrixRow{typeof(x_CT3XCT12.a)}, + face_space, + ), + (@name(a), @name(a)) => -I_CT3XC3, ) - matrix_keys = MatrixFields.FieldMatrixKeys(( - (@name(foo), @name(foo)), - (@name(a.c), @name(a.c)), - ),) - @test_all matrix[matrix_keys] == MatrixFields.FieldMatrix( - (@name(foo), @name(foo)) => matrix[@name(foo), @name(foo)], - (@name(a.c), @name(a.c)) => matrix[@name(a.c), @name(a.c)], + for (vector, matrix, I_foo, I_a) in ( + (vector_of_scalars, matrix_of_scalars, I, I), + (vector_of_vectors, matrix_of_tensors, I_C12XCT12, I_CT3XC3), ) + @test_all MatrixFields.field_vector_view(vector) == + MatrixFields.FieldVectorView( + @name(foo) => vector.foo, + @name(a) => vector.a, + ) - @test_all one(matrix) == MatrixFields.FieldMatrix( - (@name(foo), @name(foo)) => I, - (@name(a), @name(a)) => I, - ) + vector_view = MatrixFields.field_vector_view(vector) + + matrix_with_tree = MatrixFields.replace_name_tree( + matrix, + MatrixFields.FieldNameTree(vector), + ) + + # Some of the `.*`s in the following RegEx strings are needed to account + # for module qualifiers that may or may not get printed, depending on + # how these tests are run. + + @test startswith( + string(vector_view), + r""" + .*FieldVectorView with 2 entries: + @name\(foo\) => .*-valued Field: + _value: (.|\n)* + @name\(a\) => .*-valued Field: + (.|\n)*""", + ) + + @test startswith( + string(matrix), + r""" + .*FieldMatrix with 4 entries: + \(@name\(foo\), @name\(foo\)\) => .*DiagonalMatrixRow{.*}-valued Field: + entries: \ + 1: (.|\n)* + \(@name\(foo\), @name\(a.b\)\) => .*BidiagonalMatrixRow{.*}-valued Field: + entries: \ + 1: \ + _value: (.|\n)* + 2: \ + _value: (.|\n)* + \(@name\(a\), @name\(foo._value\)\) => .*QuaddiagonalMatrixRow{.*}-valued Field: + entries: (.|\n)* + \(@name\(a\), @name\(a\)\) => .*I""", + ) broken = Sys.iswindows() + + @test_throws KeyError vector_view[@name(invalid_name)] + @test_throws KeyError vector_view[@name(a.invalid_name)] + + @test_all vector_view[@name(a)] == vector.a + @test_all vector_view[@name(a.c)] == vector.a.c + @test_all vector_view[@name(foo._value)] == vector.foo._value + + @test_throws KeyError matrix[@name(invalid_name), @name(invalid_name)] + @test_throws KeyError matrix[@name(invalid_name), @name(a)] + @test_throws KeyError matrix[@name(a), @name(invalid_name)] + @test_throws KeyError matrix[@name(a), @name(a.invalid_name)] + @test_throws KeyError matrix[@name(a.invalid_name), @name(a)] + + @test_throws KeyError matrix[@name(a), @name(a.c)] + @test_throws KeyError matrix[@name(a.c), @name(a)] + @test_throws KeyError matrix[@name(foo), @name(foo._value)] + @test_throws KeyError matrix[@name(foo._value), @name(foo)] + + @test_all matrix[@name(a), @name(a)] == -I_a + @test_all matrix[@name(a.c), @name(a.c)] == -I_a + @test_all matrix[@name(a.c), @name(a.b)] == zero(I_a) + @test_all matrix[@name(foo._value), @name(foo._value)] == + matrix[@name(foo), @name(foo)] + + @test_all matrix[@name(foo._value), @name(a.b)] isa + Base.AbstractBroadcasted + @test Base.materialize(matrix[@name(foo._value), @name(a.b)]) == map( + row -> map(foo -> foo.value, row), + matrix[@name(foo), @name(a.b)], + ) + + @test_all matrix[@name(a.c), @name(foo._value)] isa + Base.AbstractBroadcasted + @test Base.materialize(matrix[@name(a.c), @name(foo._value)]) == map( + row -> map(a -> a.c, row), + matrix[@name(a), @name(foo._value)], + ) + + vector_keys = MatrixFields.FieldVectorKeys((@name(foo), @name(a.c))) + @test_all vector_view[vector_keys] == MatrixFields.FieldVectorView( + @name(foo) => vector_view[@name(foo)], + @name(a.c) => vector_view[@name(a.c)], + ) + + matrix_keys1 = MatrixFields.FieldMatrixKeys(( + (@name(foo), @name(foo)), + (@name(foo), @name(a.b)), + ),) + @test_all matrix[matrix_keys1] == MatrixFields.FieldMatrix( + (@name(foo), @name(foo)) => matrix[@name(foo), @name(foo)], + (@name(foo), @name(a.b)) => matrix[@name(foo), @name(a.b)], + ) + + matrix_keys2 = MatrixFields.FieldMatrixKeys(( + (@name(foo), @name(foo)), + (@name(a.c), @name(a.c)), # child key of (@name(a), @name(a)) + ),) + @test_throws "FieldNameTree" matrix[matrix_keys2] + @test_all matrix_with_tree[matrix_keys2] == MatrixFields.FieldMatrix( + (@name(foo), @name(foo)) => matrix[@name(foo), @name(foo)], + (@name(a.c), @name(a.c)) => matrix[@name(a.c), @name(a.c)], + ) + + partial_identity_matrix1 = MatrixFields.FieldMatrix( + (@name(foo), @name(foo)) => I_foo, + (@name(a.b), @name(a.b)) => I, # default for inferred diagonal key + ) + @test_all one(matrix[matrix_keys1]) == partial_identity_matrix1 + + partial_identity_matrix2 = MatrixFields.FieldMatrix( + (@name(foo), @name(foo)) => I_foo, + (@name(a.c), @name(a.c)) => I_a, + ) + @test_all one(matrix_with_tree[matrix_keys2]) == + partial_identity_matrix2 + + identity_matrix1 = MatrixFields.FieldMatrix( + (@name(foo), @name(foo)) => I_foo, + (@name(a), @name(a)) => I_a, + ) + @test_throws "FieldNameTree" one(matrix) + @test_all one(matrix_with_tree) == identity_matrix1 + + identity_matrix2 = MatrixFields.FieldMatrix( + (@name(foo._value), @name(foo._value)) => I_foo, + (@name(a.b), @name(a.b)) => I_a, + (@name(a.c.:(1).d), @name(a.c.:(1).d)) => I_a, + (@name(a.c.:(3).:(1)), @name(a.c.:(3).:(1))) => I_a, + ) + @test_all MatrixFields.identity_field_matrix(vector) == identity_matrix2 + + # TODO: Modify == so that identity_matrix1 == identity_matrix2 is true. + end # FieldNameDict broadcast operations are tested in field_matrix_solvers.jl. end diff --git a/test/Operators/finitedifference/linsolve.jl b/test/Operators/finitedifference/linsolve.jl deleted file mode 100644 index 2c3f692c09..0000000000 --- a/test/Operators/finitedifference/linsolve.jl +++ /dev/null @@ -1,101 +0,0 @@ -#= -julia --project=.buildkite -using Revise; include(joinpath("test", "Operators", "finitedifference", "linsolve.jl")) -=# -using Test -using ClimaComms -ClimaComms.@import_required_backends -import ClimaCore - -using ClimaCore: - Geometry, Domains, Meshes, Topologies, Spaces, Fields, Quadratures - -FT = Float32 -radius = FT(1e7) -zmax = FT(1e4) -helem = npoly = 2 -velem = 4 - -hdomain = Domains.SphereDomain(radius) -hmesh = Meshes.EquiangularCubedSphere(hdomain, helem) -context = ClimaComms.SingletonCommsContext() -device = ClimaComms.device(context) -htopology = Topologies.Topology2D(context, hmesh) -quad = Quadratures.GLL{npoly + 1}() -hspace = Spaces.SpectralElementSpace2D(htopology, quad) - -vdomain = Domains.IntervalDomain( - Geometry.ZPoint{FT}(zero(FT)), - Geometry.ZPoint{FT}(zmax); - boundary_names = (:bottom, :top), -) -vmesh = Meshes.IntervalMesh(vdomain, nelems = velem) -center_space = Spaces.CenterFiniteDifferenceSpace(device, vmesh) - -#= -# TODO: Replace this with a space that includes topography. -center_space = Spaces.ExtrudedFiniteDifferenceSpace(hspace, vspace) -center_coords = Fields.coordinate_field(center_space) -face_space = Spaces.FaceExtrudedFiniteDifferenceSpace(center_space) -=# -face_space = Spaces.FaceFiniteDifferenceSpace(center_space) - -function test_linsolve!(x, A, b, update_matrix = false; kwargs...) - - FT = Spaces.undertype(axes(x.c)) - - (; ∂ᶜρₜ∂ᶠ𝕄, ∂ᶠ𝕄ₜ∂ᶜ𝔼, ∂ᶠ𝕄ₜ∂ᶜρ) = A - - is_momentum_var(symbol) = symbol in (:uₕ, :ρuₕ, :w, :ρw) - - # Compute Schur complement - # Compute xᶠ𝕄 - xᶜρ = x.c.ρ - bᶜρ = b.c.ρ - ᶜ𝕄_name = Base.filter(is_momentum_var, propertynames(x.c))[1] - xᶜ𝕄 = getproperty(x.c, ᶜ𝕄_name) - bᶜ𝕄 = getproperty(b.c, ᶜ𝕄_name) - ᶠ𝕄_name = Base.filter(is_momentum_var, propertynames(x.f))[1] - xᶠ𝕄 = getproperty(x.f, ᶠ𝕄_name).components.data.:1 - bᶠ𝕄 = getproperty(b.f, ᶠ𝕄_name).components.data.:1 - - @. xᶠ𝕄 = bᶠ𝕄 + (apply(∂ᶠ𝕄ₜ∂ᶜρ, bᶜρ)) - - # Compute remaining components of x - @. xᶜρ = -bᶜρ + apply(∂ᶜρₜ∂ᶠ𝕄, xᶠ𝕄) - return nothing -end - -import ClimaCore -include( - joinpath(pkgdir(ClimaCore), "examples", "hybrid", "schur_complement_W.jl"), -) -jacobi_flags = (; ∂ᶜ𝔼ₜ∂ᶠ𝕄_mode = :no_∂ᶜp∂ᶜK, ∂ᶠ𝕄ₜ∂ᶜρ_mode = :exact); -use_transform = false; - -Y = Fields.FieldVector( - c = map( - coord -> ( - ρ = Float32(0), - ρe = Float32(0), - uₕ = Geometry.Covariant12Vector(Float32(0), Float32(0)), - ), - Fields.coordinate_field(center_space), - ), - f = map( - _ -> (; w = Geometry.Covariant3Vector(Float32(0))), - Fields.coordinate_field(face_space), - ), -) - -b = similar(Y) -W = SchurComplementW(Y, use_transform, jacobi_flags) - -using JET -using Test - -@testset "JET test for `apply` in linsolve! kernel" begin - test_linsolve!(Y, W, b) # compile first - @test 0 == @allocated test_linsolve!(Y, W, b) - @test_opt test_linsolve!(Y, W, b) -end diff --git a/test/runtests.jl b/test/runtests.jl index a64d8b2843..ca51830c40 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -60,7 +60,6 @@ UnitTest("Spectral elem - sphere hyperdiff vec" ,"Operators/spectralelement/u UnitTest("FD ops - column" ,"Operators/finitedifference/unit_column.jl"), UnitTest("FD ops - opt" ,"Operators/finitedifference/opt.jl"), UnitTest("FD ops - wfact" ,"Operators/finitedifference/wfact.jl"), -UnitTest("FD ops - linsolve" ,"Operators/finitedifference/linsolve.jl"), UnitTest("Hybrid - 2D" ,"Operators/hybrid/unit_2d.jl"), UnitTest("Hybrid - 3D" ,"Operators/hybrid/unit_3d.jl"), UnitTest("Hybrid - dss opt" ,"Operators/hybrid/dss_opt.jl"),