diff --git a/test/grid_geoflow.exo b/test/grid_geoflow.exo new file mode 100644 index 000000000..8abc65564 Binary files /dev/null and b/test/grid_geoflow.exo differ diff --git a/uxarray/constants.py b/uxarray/constants.py index e821de4d6..e7a1cdb69 100644 --- a/uxarray/constants.py +++ b/uxarray/constants.py @@ -16,9 +16,6 @@ # error tolerance, mainly in the intersection calculations. MACHINE_EPSILON = np.float64(np.finfo(float).eps) -ENABLE_JIT_CACHE = True -ENABLE_JIT = True - ENABLE_FMA = False GRID_DIMS = ["n_node", "n_edge", "n_face"] diff --git a/uxarray/grid/arcs.py b/uxarray/grid/arcs.py index c2ef5f009..6be158005 100644 --- a/uxarray/grid/arcs.py +++ b/uxarray/grid/arcs.py @@ -27,7 +27,7 @@ def _to_list(obj): return obj -@njit +@njit(cache=True) def _point_within_gca_body( angle, gca_cart, pt, GCRv0_lonlat, GCRv1_lonlat, pt_lonlat, is_directed ): @@ -244,7 +244,7 @@ def point_within_gca(pt, gca_cart, is_directed=False): return out -@njit +@njit(cache=True) def in_between(p, q, r) -> bool: """Determines whether the number q is between p and r. @@ -266,7 +266,7 @@ def in_between(p, q, r) -> bool: return p <= q <= r or r <= q <= p -@njit +@njit(cache=True) def _decide_pole_latitude(lat1, lat2): """Determine the pole latitude based on the latitudes of two points on a Great Circle Arc (GCA). diff --git a/uxarray/grid/area.py b/uxarray/grid/area.py index 13720e786..b785f87f8 100644 --- a/uxarray/grid/area.py +++ b/uxarray/grid/area.py @@ -2,13 +2,10 @@ from uxarray.grid.coordinates import _lonlat_rad_to_xyz -from numba import njit, config -from uxarray.constants import ENABLE_JIT_CACHE, ENABLE_JIT +from numba import njit -config.DISABLE_JIT = not ENABLE_JIT - -@njit(cache=ENABLE_JIT_CACHE) +@njit(cache=True) def calculate_face_area( x, y, z, quadrature_rule="gaussian", order=4, coords_type="spherical" ): @@ -98,7 +95,7 @@ def calculate_face_area( return area, jacobian -@njit(cache=ENABLE_JIT_CACHE) +@njit(cache=True) def get_all_face_area_from_coords( x, y, @@ -173,7 +170,7 @@ def get_all_face_area_from_coords( return area, jacobian -@njit(cache=ENABLE_JIT_CACHE) +@njit(cache=True) def calculate_spherical_triangle_jacobian(node1, node2, node3, dA, dB): """Calculate Jacobian of a spherical triangle. This is a helper function for calculating face area. @@ -263,7 +260,7 @@ def calculate_spherical_triangle_jacobian(node1, node2, node3, dA, dB): return dJacobian -@njit(cache=ENABLE_JIT_CACHE) +@njit(cache=True) def calculate_spherical_triangle_jacobian_barycentric(node1, node2, node3, dA, dB): """Calculate Jacobian of a spherical triangle. This is a helper function for calculating face area. @@ -342,7 +339,7 @@ def calculate_spherical_triangle_jacobian_barycentric(node1, node2, node3, dA, d return 0.5 * dJacobian -@njit(cache=ENABLE_JIT_CACHE) +@njit(cache=True) def get_gauss_quadratureDG(nCount): """Gauss Quadrature Points for integration. @@ -587,7 +584,7 @@ def get_gauss_quadratureDG(nCount): return dG, dW -@njit(cache=ENABLE_JIT_CACHE) +@njit(cache=True) def get_tri_quadratureDG(nOrder): """Triangular Quadrature Points for integration. diff --git a/uxarray/grid/connectivity.py b/uxarray/grid/connectivity.py index d94d9eba6..1fe9efbaf 100644 --- a/uxarray/grid/connectivity.py +++ b/uxarray/grid/connectivity.py @@ -142,7 +142,7 @@ def _populate_n_nodes_per_face(grid): ) -@njit() +@njit(cache=True) def _build_n_nodes_per_face(face_nodes, n_face, n_max_face_nodes): """Constructs ``n_nodes_per_face``, which contains the number of non-fill- value nodes for each face in ``face_node_connectivity``""" @@ -251,7 +251,7 @@ def _populate_edge_face_connectivity(grid): ) -@njit +@njit(cache=True) def _build_edge_face_connectivity(face_edges, n_nodes_per_face, n_edge): """Helper for (``edge_face_connectivity``) construction.""" edge_faces = np.ones(shape=(n_edge, 2), dtype=face_edges.dtype) * INT_FILL_VALUE diff --git a/uxarray/grid/coordinates.py b/uxarray/grid/coordinates.py index 415fa2fe1..399a03468 100644 --- a/uxarray/grid/coordinates.py +++ b/uxarray/grid/coordinates.py @@ -26,7 +26,7 @@ def _lonlat_rad_to_xyz( return x, y, z -@njit +@njit(cache=True) def _xyz_to_lonlat_rad_no_norm( x: Union[np.ndarray, float], y: Union[np.ndarray, float], @@ -67,7 +67,7 @@ def _xyz_to_lonlat_rad_no_norm( return lon, lat -@njit +@njit(cache=True) def _xyz_to_lonlat_rad_scalar( x: Union[np.ndarray, float], y: Union[np.ndarray, float], @@ -217,7 +217,7 @@ def _normalize_xyz( return x_norm, y_norm, z_norm -@njit +@njit(cache=True) def _normalize_xyz_scalar(x: float, y: float, z: float): denom = np.linalg.norm(np.asarray(np.array([x, y, z]), dtype=np.float64), ord=2) x_norm = x / denom @@ -430,7 +430,7 @@ def _smallest_enclosing_circle(points): return _welzl_recursive(points, np.empty((0, 2)), None) -@njit +@njit(cache=True) def _circle_from_two_points(p1, p2): """Calculate the smallest circle that encloses two points on a unit sphere. @@ -459,7 +459,7 @@ def _circle_from_two_points(p1, p2): return center, radius -@njit +@njit(cache=True) def _circle_from_three_points(p1, p2, p3): """Calculate the smallest circle that encloses three points on a unit sphere. This is a placeholder implementation. @@ -499,7 +499,7 @@ def _circle_from_three_points(p1, p2, p3): return center, radius -@njit +@njit(cache=True) def _is_inside_circle(circle, point): """Check if a point is inside a given circle on a unit sphere. @@ -763,7 +763,7 @@ def _xyz_to_lonlat_rad( return lon, lat -@njit +@njit(cache=True) def _xyz_to_lonlat_rad_no_norm( x: Union[np.ndarray, float], y: Union[np.ndarray, float], @@ -870,7 +870,7 @@ def _xyz_to_lonlat_deg( return lon, lat -@njit +@njit(cache=True) def _normalize_xyz_scalar(x: float, y: float, z: float): denom = np.linalg.norm(np.asarray(np.array([x, y, z]), dtype=np.float64), ord=2) x_norm = x / denom diff --git a/uxarray/grid/dual.py b/uxarray/grid/dual.py index 5dc5cb4f5..83506e4e4 100644 --- a/uxarray/grid/dual.py +++ b/uxarray/grid/dual.py @@ -2,7 +2,6 @@ from uxarray.constants import INT_FILL_VALUE, INT_DTYPE from numba import njit -from uxarray.constants import ENABLE_JIT_CACHE def construct_dual(grid): @@ -53,7 +52,7 @@ def construct_dual(grid): return new_node_face_connectivity -@njit(cache=ENABLE_JIT_CACHE) +@njit(cache=True) def construct_faces( n_node, n_edges, @@ -146,7 +145,7 @@ def construct_faces( return construct_node_face_connectivity -@njit(cache=ENABLE_JIT_CACHE) +@njit(cache=True) def _order_nodes( temp_face, node_0, diff --git a/uxarray/grid/geometry.py b/uxarray/grid/geometry.py index 09f5e50b0..c0949b9f3 100644 --- a/uxarray/grid/geometry.py +++ b/uxarray/grid/geometry.py @@ -12,7 +12,11 @@ from spatialpandas.geometry import MultiPolygonArray, PolygonArray import xarray as xr -from uxarray.constants import ERROR_TOLERANCE, INT_DTYPE, INT_FILL_VALUE +from uxarray.constants import ( + ERROR_TOLERANCE, + INT_DTYPE, + INT_FILL_VALUE, +) from uxarray.grid.arcs import extreme_gca_latitude, point_within_gca from uxarray.grid.intersections import gca_gca_intersection from uxarray.grid.utils import ( @@ -80,7 +84,7 @@ def error_radius(p1, p2): return unique_points -@njit +@njit(cache=True) def _pad_closed_face_nodes( face_node_connectivity, n_face, n_max_face_nodes, n_nodes_per_face ): diff --git a/uxarray/grid/neighbors.py b/uxarray/grid/neighbors.py index 4d48a0c88..2dcb70602 100644 --- a/uxarray/grid/neighbors.py +++ b/uxarray/grid/neighbors.py @@ -855,7 +855,7 @@ def _populate_edge_node_distances(grid): ) -@njit +@njit(cache=True) def _construct_edge_node_distances(node_lon, node_lat, edge_nodes): """Helper for computing the arc-distance between nodes compose each edge.""" @@ -890,7 +890,7 @@ def _populate_edge_face_distances(grid): ) -@njit +@njit(cache=True) def _construct_edge_face_distances(node_lon, node_lat, edge_faces): """Helper for computing the arc-distance between faces that saddle a given edge.""" diff --git a/uxarray/grid/utils.py b/uxarray/grid/utils.py index 63cb60213..1d826cf9e 100644 --- a/uxarray/grid/utils.py +++ b/uxarray/grid/utils.py @@ -6,7 +6,7 @@ from numba import njit -@njit +@njit(cache=True) def _angle_of_2_vectors(u, v): """Calculate the angle between two 3D vectors u and v in radians. Can be used to calcualte the span of a GCR. diff --git a/uxarray/utils/computing.py b/uxarray/utils/computing.py index 948431c9b..2dfa02dc2 100644 --- a/uxarray/utils/computing.py +++ b/uxarray/utils/computing.py @@ -4,7 +4,7 @@ from numba import njit -@njit +@njit(cache=True) def all(a): """Numba decorated implementation of ``np.all()`` @@ -16,7 +16,7 @@ def all(a): return np.all(a) -@njit +@njit(cache=True) def isclose(a, b, rtol=1e-05, atol=1e-08): """Numba decorated implementation of ``np.isclose()`` @@ -28,7 +28,7 @@ def isclose(a, b, rtol=1e-05, atol=1e-08): return np.isclose(a, b, rtol=rtol, atol=atol) -@njit +@njit(cache=True) def allclose(a, b, rtol=1e-05, atol=1e-08): """Numba decorated implementation of ``np.allclose()`` @@ -39,7 +39,7 @@ def allclose(a, b, rtol=1e-05, atol=1e-08): return np.allclose(a, b, rtol=rtol, atol=atol) -@njit +@njit(cache=True) def cross(a, b): """Numba decorated implementation of ``np.cross()`` @@ -50,7 +50,7 @@ def cross(a, b): return np.cross(a, b) -@njit +@njit(cache=True) def dot(a, b): """Numba decorated implementation of ``np.dot()`` @@ -61,7 +61,7 @@ def dot(a, b): return np.dot(a, b) -@njit +@njit(cache=True) def norm(x): """Numba decorated implementation of ``np.linalg.norm()`` diff --git a/uxarray/utils/numba_settings.py b/uxarray/utils/numba_settings.py deleted file mode 100644 index ea7e28ffc..000000000 --- a/uxarray/utils/numba_settings.py +++ /dev/null @@ -1,49 +0,0 @@ -import uxarray.constants - - -def enable_jit_cache(): - """Allows Numba's JIT cache to be turned on. - - This cache variable lets @njit cache the machine code generated - between runs, allowing for faster run times due to the fact that the - code doesn't need to regenerate the machine code every run time. Our - use case here was to study performance, in regular usage one might - never turn off caching as it will only help if frequently modifying - the code or because users have very limited disk space. The default - is on (True) - """ - uxarray.constants.ENABLE_JIT_CACHE = True - - -def disable_jit_cache(): - """Allows Numba's JIT cache to be turned on off. - - This cache variable lets @njit cache the machine code generated - between runs, allowing for faster run times due to the fact that the - code doesn't need to regenerate the machine code every run time. Our - use case here was to study performance, in regular usage one might - never turn off caching as it will only help if frequently modifying - the code or because users have very limited disk space. The default - is on (True) - """ - uxarray.constants.ENABLE_JIT_CACHE = False - - -def enable_jit(): - """Allows Numba's JIT application to be turned on. - - This lets users choose whether they want machine code to be - generated to speed up the performance of the code on large files. - The default is on (True) - """ - uxarray.constants.ENABLE_JIT = True - - -def disable_jit(): - """Allows Numba's JIT application to be turned off. - - This lets users choose whether they want machine code to be - generated to speed up the performance of the code on large files. - The default is on (True) - """ - uxarray.constants.ENABLE_JIT = False