23
23
24
24
import numpy as np
25
25
26
+ from meshmode .mesh import BTAG_ALL
26
27
import meshmode .mesh .generation as mgen
27
28
28
29
from pytools .obj_array import make_obj_array
29
30
30
31
from grudge import op , geometry as geo , DiscretizationCollection
31
- from grudge .dof_desc import DOFDesc
32
+ from grudge .discretization import make_discretization_collection
33
+ from grudge .dof_desc import DOFDesc , as_dofdesc
32
34
33
35
import pytest
34
36
35
37
from grudge .array_context import PytestPyOpenCLArrayContextFactory
36
38
from arraycontext import pytest_generate_tests_for_array_contexts
39
+
40
+ from grudge .trace_pair import bdry_trace_pair , bv_trace_pair
37
41
pytest_generate_tests = pytest_generate_tests_for_array_contexts (
38
42
[PytestPyOpenCLArrayContextFactory ])
39
43
47
51
@pytest .mark .parametrize ("form" , ["strong" , "weak" ])
48
52
@pytest .mark .parametrize ("dim" , [1 , 2 , 3 ])
49
53
@pytest .mark .parametrize ("order" , [2 , 3 ])
54
+ @pytest .mark .parametrize ("warp_mesh" , [False , True ])
50
55
@pytest .mark .parametrize (("vectorize" , "nested" ), [
51
56
(False , False ),
52
57
(True , False ),
53
58
(True , True )
54
59
])
55
60
def test_gradient (actx_factory , form , dim , order , vectorize , nested ,
56
- visualize = False ):
61
+ warp_mesh , visualize = False ):
57
62
actx = actx_factory ()
58
63
59
64
from pytools .convergence import EOCRecorder
60
65
eoc_rec = EOCRecorder ()
61
66
62
- for n in [4 , 6 , 8 ]:
63
- mesh = mgen .generate_regular_rect_mesh (
64
- a = (- 1 ,)* dim , b = (1 ,)* dim ,
65
- nelements_per_axis = (n ,)* dim )
67
+ for n in [8 , 12 , 16 ] if warp_mesh else [4 , 6 , 8 ]:
68
+ if warp_mesh :
69
+ if dim == 1 :
70
+ pytest .skip ("warped mesh in 1D not implemented" )
71
+ mesh = mgen .generate_warped_rect_mesh (
72
+ dim = dim , order = order , nelements_side = n )
73
+ else :
74
+ mesh = mgen .generate_regular_rect_mesh (
75
+ a = (- 1 ,)* dim , b = (1 ,)* dim ,
76
+ nelements_per_axis = (n ,)* dim )
66
77
67
- dcoll = DiscretizationCollection (actx , mesh , order = order )
78
+ dcoll = make_discretization_collection (actx , mesh , order = order )
68
79
69
80
def f (x ):
70
- result = dcoll . zeros ( actx ) + 1
81
+ result = 1
71
82
for i in range (dim - 1 ):
72
83
result = result * actx .np .sin (np .pi * x [i ])
73
84
result = result * actx .np .cos (np .pi / 2 * x [dim - 1 ])
@@ -89,14 +100,17 @@ def grad_f(x):
89
100
90
101
x = actx .thaw (dcoll .nodes ())
91
102
92
- if vectorize :
93
- u = make_obj_array ([(i + 1 )* f (x ) for i in range (dim )])
94
- else :
95
- u = f (x )
103
+ def vectorize_if_requested (vec ):
104
+ if vectorize :
105
+ return make_obj_array ([(i + 1 )* vec for i in range (dim )])
106
+ else :
107
+ return vec
108
+
109
+ u = vectorize_if_requested (f (x ))
96
110
97
111
def get_flux (u_tpair ):
98
112
dd = u_tpair .dd
99
- dd_allfaces = dd .with_dtag ("all_faces" )
113
+ dd_allfaces = dd .with_domain_tag ("all_faces" )
100
114
normal = geo .normal (actx , dcoll , dd )
101
115
u_avg = u_tpair .avg
102
116
if vectorize :
@@ -108,7 +122,11 @@ def get_flux(u_tpair):
108
122
flux = u_avg * normal
109
123
return op .project (dcoll , dd , dd_allfaces , flux )
110
124
111
- dd_allfaces = DOFDesc ("all_faces" )
125
+ dd_allfaces = as_dofdesc ("all_faces" )
126
+
127
+ bdry_dd = as_dofdesc (BTAG_ALL )
128
+ bdry_x = actx .thaw (dcoll .nodes (bdry_dd ))
129
+ bdry_u = vectorize_if_requested (f (bdry_x ))
112
130
113
131
if form == "strong" :
114
132
grad_u = (
@@ -121,9 +139,9 @@ def get_flux(u_tpair):
121
139
+ # noqa: W504
122
140
op .face_mass (dcoll ,
123
141
dd_allfaces ,
124
- # Note: no boundary flux terms here because u_ext == u_int == 0
125
142
sum (get_flux (utpair )
126
143
for utpair in op .interior_trace_pairs (dcoll , u ))
144
+ + get_flux (bv_trace_pair (dcoll , bdry_dd , u , bdry_u ))
127
145
)
128
146
)
129
147
else :
@@ -138,6 +156,9 @@ def get_flux(u_tpair):
138
156
expected_grad_u = grad_f (x )
139
157
140
158
if visualize :
159
+ # the code below does not handle the vectorized case
160
+ assert not vectorize
161
+
141
162
from grudge .shortcuts import make_visualizer
142
163
vis = make_visualizer (dcoll , vis_order = order if dim == 3 else dim + 3 )
143
164
0 commit comments