Skip to content

Commit 2315fec

Browse files
committed
test_gradient: also test on warped mesh
1 parent 0da9e04 commit 2315fec

File tree

1 file changed

+36
-15
lines changed

1 file changed

+36
-15
lines changed

test/test_op.py

+36-15
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,21 @@
2323

2424
import numpy as np
2525

26+
from meshmode.mesh import BTAG_ALL
2627
import meshmode.mesh.generation as mgen
2728

2829
from pytools.obj_array import make_obj_array
2930

3031
from grudge import op, geometry as geo, DiscretizationCollection
31-
from grudge.dof_desc import DOFDesc
32+
from grudge.discretization import make_discretization_collection
33+
from grudge.dof_desc import DOFDesc, as_dofdesc
3234

3335
import pytest
3436

3537
from grudge.array_context import PytestPyOpenCLArrayContextFactory
3638
from arraycontext import pytest_generate_tests_for_array_contexts
39+
40+
from grudge.trace_pair import bdry_trace_pair, bv_trace_pair
3741
pytest_generate_tests = pytest_generate_tests_for_array_contexts(
3842
[PytestPyOpenCLArrayContextFactory])
3943

@@ -47,27 +51,34 @@
4751
@pytest.mark.parametrize("form", ["strong", "weak"])
4852
@pytest.mark.parametrize("dim", [1, 2, 3])
4953
@pytest.mark.parametrize("order", [2, 3])
54+
@pytest.mark.parametrize("warp_mesh", [False, True])
5055
@pytest.mark.parametrize(("vectorize", "nested"), [
5156
(False, False),
5257
(True, False),
5358
(True, True)
5459
])
5560
def test_gradient(actx_factory, form, dim, order, vectorize, nested,
56-
visualize=False):
61+
warp_mesh, visualize=False):
5762
actx = actx_factory()
5863

5964
from pytools.convergence import EOCRecorder
6065
eoc_rec = EOCRecorder()
6166

62-
for n in [4, 6, 8]:
63-
mesh = mgen.generate_regular_rect_mesh(
64-
a=(-1,)*dim, b=(1,)*dim,
65-
nelements_per_axis=(n,)*dim)
67+
for n in [8, 12, 16] if warp_mesh else [4, 6, 8]:
68+
if warp_mesh:
69+
if dim == 1:
70+
pytest.skip("warped mesh in 1D not implemented")
71+
mesh = mgen.generate_warped_rect_mesh(
72+
dim=dim, order=order, nelements_side=n)
73+
else:
74+
mesh = mgen.generate_regular_rect_mesh(
75+
a=(-1,)*dim, b=(1,)*dim,
76+
nelements_per_axis=(n,)*dim)
6677

67-
dcoll = DiscretizationCollection(actx, mesh, order=order)
78+
dcoll = make_discretization_collection(actx, mesh, order=order)
6879

6980
def f(x):
70-
result = dcoll.zeros(actx) + 1
81+
result = 1
7182
for i in range(dim-1):
7283
result = result * actx.np.sin(np.pi*x[i])
7384
result = result * actx.np.cos(np.pi/2*x[dim-1])
@@ -89,14 +100,17 @@ def grad_f(x):
89100

90101
x = actx.thaw(dcoll.nodes())
91102

92-
if vectorize:
93-
u = make_obj_array([(i+1)*f(x) for i in range(dim)])
94-
else:
95-
u = f(x)
103+
def vectorize_if_requested(vec):
104+
if vectorize:
105+
return make_obj_array([(i+1)*vec for i in range(dim)])
106+
else:
107+
return vec
108+
109+
u = vectorize_if_requested(f(x))
96110

97111
def get_flux(u_tpair):
98112
dd = u_tpair.dd
99-
dd_allfaces = dd.with_dtag("all_faces")
113+
dd_allfaces = dd.with_domain_tag("all_faces")
100114
normal = geo.normal(actx, dcoll, dd)
101115
u_avg = u_tpair.avg
102116
if vectorize:
@@ -108,7 +122,11 @@ def get_flux(u_tpair):
108122
flux = u_avg * normal
109123
return op.project(dcoll, dd, dd_allfaces, flux)
110124

111-
dd_allfaces = DOFDesc("all_faces")
125+
dd_allfaces = as_dofdesc("all_faces")
126+
127+
bdry_dd = as_dofdesc(BTAG_ALL)
128+
bdry_x = actx.thaw(dcoll.nodes(bdry_dd))
129+
bdry_u = vectorize_if_requested(f(bdry_x))
112130

113131
if form == "strong":
114132
grad_u = (
@@ -121,9 +139,9 @@ def get_flux(u_tpair):
121139
+ # noqa: W504
122140
op.face_mass(dcoll,
123141
dd_allfaces,
124-
# Note: no boundary flux terms here because u_ext == u_int == 0
125142
sum(get_flux(utpair)
126143
for utpair in op.interior_trace_pairs(dcoll, u))
144+
+ get_flux(bv_trace_pair(dcoll, bdry_dd, u, bdry_u))
127145
)
128146
)
129147
else:
@@ -138,6 +156,9 @@ def get_flux(u_tpair):
138156
expected_grad_u = grad_f(x)
139157

140158
if visualize:
159+
# the code below does not handle the vectorized case
160+
assert not vectorize
161+
141162
from grudge.shortcuts import make_visualizer
142163
vis = make_visualizer(dcoll, vis_order=order if dim == 3 else dim+3)
143164

0 commit comments

Comments
 (0)