Skip to content

Commit

Permalink
Do not store grids, destroy them after usage (#402)
Browse files Browse the repository at this point in the history
* Destroy grid once used, do not store it

* upd changes

* Release mesh memory too

---------

Co-authored-by: raphael dussin <[email protected]>
  • Loading branch information
aulemahal and raphaeldussin authored Feb 7, 2025
1 parent 12fd94a commit fceddc8
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 37 deletions.
4 changes: 4 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
What's new
==========

0.8.9 (unreleased)
------------------
* Destroy grids explicitly once weights are computed. Do not store them in `grid_in` and `grid_out` attributes. This fixes segmentation faults introduced by the memory fix in the previous version. By `Pascal Bourgault <https://github.com/aulemahal>`_.

0.8.8 (2024-11-01)
------------------
* Fix ESMpy memory issues by explicitly freeing the Grid memory upon garbage collection of ``Regridder`` objects. By `Pascal Bourgault <https://github.com/aulemahal>`_.
Expand Down
6 changes: 1 addition & 5 deletions xesmf/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,19 +546,15 @@ def esmf_regrid_finalize(regrid):
regrid : ESMF.Regrid object
"""

# We do not destroy the Grids here, as they might be reused between multiple regrids
regrid.destroy()
regrid.srcfield.destroy()
regrid.dstfield.destroy()
# regrid.srcfield.grid.destroy()
# regrid.dstfield.grid.destroy()

# double check
assert regrid.finalized
assert regrid.srcfield.finalized
assert regrid.dstfield.finalized
# assert regrid.srcfield.grid.finalized
# assert regrid.dstfield.grid.finalized


# Deprecated as of version 0.5.0
Expand Down
55 changes: 23 additions & 32 deletions xesmf/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,17 +334,15 @@ def __init__(
baseregridder : xESMF BaseRegridder object
"""
self.grid_in = grid_in
self.grid_out = grid_out
self.method = method
self.reuse_weights = reuse_weights
self.extrap_method = extrap_method
self.extrap_dist_exponent = extrap_dist_exponent
self.extrap_num_src_pnts = extrap_num_src_pnts
self.ignore_degenerate = ignore_degenerate
self.periodic = getattr(self.grid_in, 'periodic_dim', None) is not None
self.sequence_in = isinstance(self.grid_in, (LocStream, Mesh))
self.sequence_out = isinstance(self.grid_out, (LocStream, Mesh))
self.periodic = getattr(grid_in, 'periodic_dim', None) is not None
self.sequence_in = isinstance(grid_in, (LocStream, Mesh))
self.sequence_out = isinstance(grid_out, (LocStream, Mesh))

if input_dims is not None and len(input_dims) != int(not self.sequence_in) + 1:
raise ValueError(f'Wrong number of dimension names in `input_dims` ({len(input_dims)}.')
Expand All @@ -358,8 +356,8 @@ def __init__(

# record grid shape information
# We need to invert Grid shapes to respect xESMF's convention (y, x).
self.shape_in = self.grid_in.get_shape()[::-1]
self.shape_out = self.grid_out.get_shape()[::-1]
self.shape_in = grid_in.get_shape()[::-1]
self.shape_out = grid_out.get_shape()[::-1]
self.n_in = self.shape_in[0] * self.shape_in[1]
self.n_out = self.shape_out[0] * self.shape_out[1]

Expand All @@ -369,7 +367,7 @@ def __init__(

if not parallel:
if not reuse_weights and weights is None:
weights = self._compute_weights() # Dictionary of weights
weights = self._compute_weights(grid_in, grid_out) # Dictionary of weights
else:
weights = filename if filename is not None else weights

Expand All @@ -380,7 +378,7 @@ def __init__(

# replace zeros by NaN for weight matrix entries of unmapped target cells if specified or a mask is present
if (
(self.grid_out.mask is not None) and (self.grid_out.mask[0] is not None)
(grid_out.mask is not None) and (grid_out.mask[0] is not None)
) or unmapped_to_nan is True:
self.weights = add_nans_to_weights(self.weights)

Expand Down Expand Up @@ -435,10 +433,10 @@ def _get_default_filename(self):

return filename

def _compute_weights(self):
def _compute_weights(self, grid_in, grid_out):
regrid = esmf_regrid_build(
self.grid_in,
self.grid_out,
grid_in,
grid_out,
self.method,
extrap_method=self.extrap_method,
extrap_dist_exponent=self.extrap_dist_exponent,
Expand Down Expand Up @@ -934,6 +932,9 @@ def __init__(
parallel=parallel,
**kwargs,
)
# Weights are computed, we do not need the grids anymore
grid_in.destroy()
grid_out.destroy()

# Record output grid and metadata
lon_out, lat_out = _get_lon_lat(ds_out)
Expand Down Expand Up @@ -1109,13 +1110,6 @@ def _format_xroutput(self, out, new_dims=None):

return out

def __del__(self):
# Memory leak issue when regridding over a large number of datasets with xESMF
# https://github.com/JiaweiZhuang/xESMF/issues/53
if hasattr(self, 'grid_in'): # If the init has failed, grid_in isn't there
self.grid_in.destroy()
self.grid_out.destroy()


class SpatialAverager(BaseRegridder):
def __init__(
Expand Down Expand Up @@ -1249,6 +1243,9 @@ def __init__(
ignore_degenerate=ignore_degenerate,
unmapped_to_nan=False,
)
# Weights are computed, we do not need the grids anymore
grid_in.destroy()
locstream_out.destroy()

@staticmethod
def _check_polys_length(polys, threshold=1):
Expand All @@ -1267,12 +1264,12 @@ def _check_polys_length(polys, threshold=1):
stacklevel=2,
)

def _compute_weights_and_area(self, mesh_out):
def _compute_weights_and_area(self, grid_in, mesh_out):
"""Return the weights and the area of the destination mesh cells."""

# Build the regrid object
regrid = esmf_regrid_build(
self.grid_in,
grid_in,
mesh_out,
method='conservative',
ignore_degenerate=self.ignore_degenerate,
Expand All @@ -1286,10 +1283,9 @@ def _compute_weights_and_area(self, mesh_out):
regrid.dstfield.get_area()
dstarea = regrid.dstfield.data.copy()

esmf_regrid_finalize(regrid)
return w, dstarea

def _compute_weights(self):
def _compute_weights(self, grid_in, grid_out):
"""Return weight sparse matrix.
This function first explodes the geometries into a flat list of Polygon exterior objects:
Expand All @@ -1315,14 +1311,16 @@ def _compute_weights(self):
mesh_ext = Mesh.from_polygons(exteriors)

# Get weights for external polygons
w, area = self._compute_weights_and_area(mesh_ext)
w, area = self._compute_weights_and_area(grid_in, mesh_ext)
mesh_ext.destroy() # release mesh memory

# Get weights for interiors and append them to weights from exteriors as a negative contribution.
if len(interiors) > 0 and not self.ignore_holes:
mesh_int = Mesh.from_polygons(interiors)

# Get weights for interiors
w_int, area_int = self._compute_weights_and_area(mesh_int)
w_int, area_int = self._compute_weights_and_area(grid_in, mesh_int)
mesh_int.destroy() # release mesh memory

# Append weights from holes as negative weights
# In sparse >= 0.16, a fill_value of -0.0 is different from 0.0 and the concat would fail
Expand Down Expand Up @@ -1382,10 +1380,3 @@ def _format_xroutput(self, out, new_dims=None):
out.coords[self._lat_out_name] = xr.DataArray(self._lat_out, dims=(self.geom_dim_name,))
out.attrs['regrid_method'] = self.method
return out

def __del__(self):
# Memory leak issue when regridding over a large number of datasets with xESMF
# https://github.com/JiaweiZhuang/xESMF/issues/53
if hasattr(self, 'grid_in'): # If the init has failed, grid_in isn't there
self.grid_in.destroy()
self.grid_out.destroy()

0 comments on commit fceddc8

Please sign in to comment.