Skip to content

Commit

Permalink
Move free_cuda functions to utils
Browse files Browse the repository at this point in the history
  • Loading branch information
relativityhd committed Nov 26, 2024
1 parent 522d026 commit 9600e7f
Show file tree
Hide file tree
Showing 7 changed files with 34 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@

import numpy as np
import xarray as xr
from darts_utils.cuda import free_cupy
from skimage.morphology import binary_erosion, disk, label, remove_small_objects

from darts_postprocessing.utils import free_cuda

logger = logging.getLogger(__name__.replace("darts_", "darts."))

try:
Expand Down Expand Up @@ -64,7 +63,7 @@ def erode_mask(mask: xr.DataArray, size: int, device: Literal["cuda", "cpu"] | i
mask = mask.cupy.as_cupy()
mask.values = binary_erosion_gpu(mask.data, disk_gpu(size))
mask = mask.cupy.as_numpy()
free_cuda()
free_cupy()
else:
mask.values = binary_erosion(mask.values, disk(size))

Expand Down Expand Up @@ -135,7 +134,7 @@ def binarize(
binarized.astype(bool).expand_dims("batch", 0).data, min_size=min_object_size
)[0]
binarized = binarized.cupy.as_numpy()
free_cuda()
free_cupy()
else:
binarized.values = remove_small_objects(
binarized.astype(bool).expand_dims("batch", 0).values, min_size=min_object_size
Expand Down
16 changes: 0 additions & 16 deletions darts-postprocessing/src/darts_postprocessing/utils.py

This file was deleted.

4 changes: 2 additions & 2 deletions darts-preprocessing/src/darts_preprocessing/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@

import odc.geo.xr # noqa: F401
import xarray as xr
from darts_utils.cuda import free_cupy
from xrspatial.utils import has_cuda_and_cupy

from darts_preprocessing.engineering.arcticdem import calculate_slope, calculate_topographic_position_index
from darts_preprocessing.engineering.indices import calculate_ndvi
from darts_preprocessing.utils import free_cuda

logger = logging.getLogger(__name__.replace("darts_", "darts."))

Expand Down Expand Up @@ -101,7 +101,7 @@ def preprocess_legacy_arcticdem_fast(
ds_arcticdem = calculate_topographic_position_index(ds_arcticdem, tpi_outer_radius, tpi_inner_radius)
ds_arcticdem = calculate_slope(ds_arcticdem)
ds_arcticdem = ds_arcticdem.cupy.as_numpy()
free_cuda()
free_cupy()

# Calculate TPI and slope from ArcticDEM on CPU
else:
Expand Down
16 changes: 0 additions & 16 deletions darts-preprocessing/src/darts_preprocessing/utils.py

This file was deleted.

7 changes: 4 additions & 3 deletions darts-segmentation/src/darts_segmentation/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
import torch
import torch.nn as nn
import xarray as xr
from darts_utils.cuda import free_torch

from darts_segmentation.utils import free_cuda, predict_in_patches
from darts_segmentation.utils import predict_in_patches

logger = logging.getLogger(__name__.replace("darts_", "darts."))

Expand Down Expand Up @@ -156,7 +157,7 @@ def segment_tile(

# Cleanup cuda memory
del tensor_tile, probabilities
free_cuda()
free_torch()

return tile

Expand Down Expand Up @@ -205,7 +206,7 @@ def segment_tile_batched(

# Cleanup cuda memory
del tensor_tiles, probabilities
free_cuda()
free_torch()

return tiles

Expand Down
7 changes: 0 additions & 7 deletions darts-segmentation/src/darts_segmentation/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Shared utilities for the inference modules."""

import gc
import logging
import math
import time
Expand Down Expand Up @@ -184,9 +183,3 @@ def predict_in_patches(
return prediction, weights
else:
return prediction


def free_cuda():
"""Free the CUDA memory."""
gc.collect()
torch.cuda.empty_cache()
25 changes: 25 additions & 0 deletions darts-utils/src/darts_utils/cuda.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""Utility functions around cuda, e.g. memory management."""

import gc

import torch

try:
import cupy as cp
except ImportError:
cp = None


def free_cupy():
"""Free the CUDA memory of cupy."""
if cp is not None:
gc.collect()
cp.get_default_memory_pool().free_all_blocks()
cp.get_default_pinned_memory_pool().free_all_blocks()


def free_torch():
"""Free the CUDA memory of pytorch."""
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()

0 comments on commit 9600e7f

Please sign in to comment.