Skip to content

Commit

Permalink
JIT Cleanup (#1007)
Browse files Browse the repository at this point in the history
* Initial Commit

* removed ENABLE_JIT_CACHE
  • Loading branch information
aaronzedwick authored Oct 12, 2024
1 parent 691999b commit bd8901f
Show file tree
Hide file tree
Showing 12 changed files with 37 additions and 89 deletions.
Binary file added test/grid_geoflow.exo
Binary file not shown.
3 changes: 0 additions & 3 deletions uxarray/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@
# error tolerance, mainly in the intersection calculations.
MACHINE_EPSILON = np.float64(np.finfo(float).eps)

ENABLE_JIT_CACHE = True
ENABLE_JIT = True

ENABLE_FMA = False

GRID_DIMS = ["n_node", "n_edge", "n_face"]
Expand Down
6 changes: 3 additions & 3 deletions uxarray/grid/arcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def _to_list(obj):
return obj


@njit
@njit(cache=True)
def _point_within_gca_body(
angle, gca_cart, pt, GCRv0_lonlat, GCRv1_lonlat, pt_lonlat, is_directed
):
Expand Down Expand Up @@ -244,7 +244,7 @@ def point_within_gca(pt, gca_cart, is_directed=False):
return out


@njit
@njit(cache=True)
def in_between(p, q, r) -> bool:
"""Determines whether the number q is between p and r.
Expand All @@ -266,7 +266,7 @@ def in_between(p, q, r) -> bool:
return p <= q <= r or r <= q <= p


@njit
@njit(cache=True)
def _decide_pole_latitude(lat1, lat2):
"""Determine the pole latitude based on the latitudes of two points on a
Great Circle Arc (GCA).
Expand Down
17 changes: 7 additions & 10 deletions uxarray/grid/area.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,10 @@

from uxarray.grid.coordinates import _lonlat_rad_to_xyz

from numba import njit, config
from uxarray.constants import ENABLE_JIT_CACHE, ENABLE_JIT
from numba import njit

config.DISABLE_JIT = not ENABLE_JIT


@njit(cache=ENABLE_JIT_CACHE)
@njit(cache=True)
def calculate_face_area(
x, y, z, quadrature_rule="gaussian", order=4, coords_type="spherical"
):
Expand Down Expand Up @@ -98,7 +95,7 @@ def calculate_face_area(
return area, jacobian


@njit(cache=ENABLE_JIT_CACHE)
@njit(cache=True)
def get_all_face_area_from_coords(
x,
y,
Expand Down Expand Up @@ -173,7 +170,7 @@ def get_all_face_area_from_coords(
return area, jacobian


@njit(cache=ENABLE_JIT_CACHE)
@njit(cache=True)
def calculate_spherical_triangle_jacobian(node1, node2, node3, dA, dB):
"""Calculate Jacobian of a spherical triangle. This is a helper function
for calculating face area.
Expand Down Expand Up @@ -263,7 +260,7 @@ def calculate_spherical_triangle_jacobian(node1, node2, node3, dA, dB):
return dJacobian


@njit(cache=ENABLE_JIT_CACHE)
@njit(cache=True)
def calculate_spherical_triangle_jacobian_barycentric(node1, node2, node3, dA, dB):
"""Calculate Jacobian of a spherical triangle. This is a helper function
for calculating face area.
Expand Down Expand Up @@ -342,7 +339,7 @@ def calculate_spherical_triangle_jacobian_barycentric(node1, node2, node3, dA, d
return 0.5 * dJacobian


@njit(cache=ENABLE_JIT_CACHE)
@njit(cache=True)
def get_gauss_quadratureDG(nCount):
"""Gauss Quadrature Points for integration.
Expand Down Expand Up @@ -587,7 +584,7 @@ def get_gauss_quadratureDG(nCount):
return dG, dW


@njit(cache=ENABLE_JIT_CACHE)
@njit(cache=True)
def get_tri_quadratureDG(nOrder):
"""Triangular Quadrature Points for integration.
Expand Down
4 changes: 2 additions & 2 deletions uxarray/grid/connectivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def _populate_n_nodes_per_face(grid):
)


@njit()
@njit(cache=True)
def _build_n_nodes_per_face(face_nodes, n_face, n_max_face_nodes):
"""Constructs ``n_nodes_per_face``, which contains the number of non-fill-
value nodes for each face in ``face_node_connectivity``"""
Expand Down Expand Up @@ -251,7 +251,7 @@ def _populate_edge_face_connectivity(grid):
)


@njit
@njit(cache=True)
def _build_edge_face_connectivity(face_edges, n_nodes_per_face, n_edge):
"""Helper for (``edge_face_connectivity``) construction."""
edge_faces = np.ones(shape=(n_edge, 2), dtype=face_edges.dtype) * INT_FILL_VALUE
Expand Down
16 changes: 8 additions & 8 deletions uxarray/grid/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def _lonlat_rad_to_xyz(
return x, y, z


@njit
@njit(cache=True)
def _xyz_to_lonlat_rad_no_norm(
x: Union[np.ndarray, float],
y: Union[np.ndarray, float],
Expand Down Expand Up @@ -67,7 +67,7 @@ def _xyz_to_lonlat_rad_no_norm(
return lon, lat


@njit
@njit(cache=True)
def _xyz_to_lonlat_rad_scalar(
x: Union[np.ndarray, float],
y: Union[np.ndarray, float],
Expand Down Expand Up @@ -217,7 +217,7 @@ def _normalize_xyz(
return x_norm, y_norm, z_norm


@njit
@njit(cache=True)
def _normalize_xyz_scalar(x: float, y: float, z: float):
denom = np.linalg.norm(np.asarray(np.array([x, y, z]), dtype=np.float64), ord=2)
x_norm = x / denom
Expand Down Expand Up @@ -430,7 +430,7 @@ def _smallest_enclosing_circle(points):
return _welzl_recursive(points, np.empty((0, 2)), None)


@njit
@njit(cache=True)
def _circle_from_two_points(p1, p2):
"""Calculate the smallest circle that encloses two points on a unit sphere.
Expand Down Expand Up @@ -459,7 +459,7 @@ def _circle_from_two_points(p1, p2):
return center, radius


@njit
@njit(cache=True)
def _circle_from_three_points(p1, p2, p3):
"""Calculate the smallest circle that encloses three points on a unit
sphere. This is a placeholder implementation.
Expand Down Expand Up @@ -499,7 +499,7 @@ def _circle_from_three_points(p1, p2, p3):
return center, radius


@njit
@njit(cache=True)
def _is_inside_circle(circle, point):
"""Check if a point is inside a given circle on a unit sphere.
Expand Down Expand Up @@ -763,7 +763,7 @@ def _xyz_to_lonlat_rad(
return lon, lat


@njit
@njit(cache=True)
def _xyz_to_lonlat_rad_no_norm(
x: Union[np.ndarray, float],
y: Union[np.ndarray, float],
Expand Down Expand Up @@ -870,7 +870,7 @@ def _xyz_to_lonlat_deg(
return lon, lat


@njit
@njit(cache=True)
def _normalize_xyz_scalar(x: float, y: float, z: float):
denom = np.linalg.norm(np.asarray(np.array([x, y, z]), dtype=np.float64), ord=2)
x_norm = x / denom
Expand Down
5 changes: 2 additions & 3 deletions uxarray/grid/dual.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from uxarray.constants import INT_FILL_VALUE, INT_DTYPE

from numba import njit
from uxarray.constants import ENABLE_JIT_CACHE


def construct_dual(grid):
Expand Down Expand Up @@ -53,7 +52,7 @@ def construct_dual(grid):
return new_node_face_connectivity


@njit(cache=ENABLE_JIT_CACHE)
@njit(cache=True)
def construct_faces(
n_node,
n_edges,
Expand Down Expand Up @@ -146,7 +145,7 @@ def construct_faces(
return construct_node_face_connectivity


@njit(cache=ENABLE_JIT_CACHE)
@njit(cache=True)
def _order_nodes(
temp_face,
node_0,
Expand Down
8 changes: 6 additions & 2 deletions uxarray/grid/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
from spatialpandas.geometry import MultiPolygonArray, PolygonArray
import xarray as xr

from uxarray.constants import ERROR_TOLERANCE, INT_DTYPE, INT_FILL_VALUE
from uxarray.constants import (
ERROR_TOLERANCE,
INT_DTYPE,
INT_FILL_VALUE,
)
from uxarray.grid.arcs import extreme_gca_latitude, point_within_gca
from uxarray.grid.intersections import gca_gca_intersection
from uxarray.grid.utils import (
Expand Down Expand Up @@ -80,7 +84,7 @@ def error_radius(p1, p2):
return unique_points


@njit
@njit(cache=True)
def _pad_closed_face_nodes(
face_node_connectivity, n_face, n_max_face_nodes, n_nodes_per_face
):
Expand Down
4 changes: 2 additions & 2 deletions uxarray/grid/neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -855,7 +855,7 @@ def _populate_edge_node_distances(grid):
)


@njit
@njit(cache=True)
def _construct_edge_node_distances(node_lon, node_lat, edge_nodes):
"""Helper for computing the arc-distance between nodes compose each
edge."""
Expand Down Expand Up @@ -890,7 +890,7 @@ def _populate_edge_face_distances(grid):
)


@njit
@njit(cache=True)
def _construct_edge_face_distances(node_lon, node_lat, edge_faces):
"""Helper for computing the arc-distance between faces that saddle a given
edge."""
Expand Down
2 changes: 1 addition & 1 deletion uxarray/grid/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from numba import njit


@njit
@njit(cache=True)
def _angle_of_2_vectors(u, v):
"""Calculate the angle between two 3D vectors u and v in radians. Can be
used to calcualte the span of a GCR.
Expand Down
12 changes: 6 additions & 6 deletions uxarray/utils/computing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from numba import njit


@njit
@njit(cache=True)
def all(a):
"""Numba decorated implementation of ``np.all()``
Expand All @@ -16,7 +16,7 @@ def all(a):
return np.all(a)


@njit
@njit(cache=True)
def isclose(a, b, rtol=1e-05, atol=1e-08):
"""Numba decorated implementation of ``np.isclose()``
Expand All @@ -28,7 +28,7 @@ def isclose(a, b, rtol=1e-05, atol=1e-08):
return np.isclose(a, b, rtol=rtol, atol=atol)


@njit
@njit(cache=True)
def allclose(a, b, rtol=1e-05, atol=1e-08):
"""Numba decorated implementation of ``np.allclose()``
Expand All @@ -39,7 +39,7 @@ def allclose(a, b, rtol=1e-05, atol=1e-08):
return np.allclose(a, b, rtol=rtol, atol=atol)


@njit
@njit(cache=True)
def cross(a, b):
"""Numba decorated implementation of ``np.cross()``
Expand All @@ -50,7 +50,7 @@ def cross(a, b):
return np.cross(a, b)


@njit
@njit(cache=True)
def dot(a, b):
"""Numba decorated implementation of ``np.dot()``
Expand All @@ -61,7 +61,7 @@ def dot(a, b):
return np.dot(a, b)


@njit
@njit(cache=True)
def norm(x):
"""Numba decorated implementation of ``np.linalg.norm()``
Expand Down
49 changes: 0 additions & 49 deletions uxarray/utils/numba_settings.py

This file was deleted.

0 comments on commit bd8901f

Please sign in to comment.