Skip to content

Commit

Permalink
accelerate pc_pyramid by using handwritten data downsampling and writ…
Browse files Browse the repository at this point in the history
…ing code
  • Loading branch information
kanglcn committed Aug 16, 2024
1 parent 3c6e9b3 commit e151d4b
Show file tree
Hide file tree
Showing 3 changed files with 361 additions and 494 deletions.
7 changes: 2 additions & 5 deletions moraine/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,21 +106,18 @@
'moraine.cli.plot._is_nan_range': ('CLI/plot.html#_is_nan_range', 'moraine/cli/plot.py'),
'moraine.cli.plot._next_level_idx_from_raster_of_integer': ( 'CLI/plot.html#_next_level_idx_from_raster_of_integer',
'moraine/cli/plot.py'),
'moraine.cli.plot._next_level_idx_from_raster_of_noninteger': ( 'CLI/plot.html#_next_level_idx_from_raster_of_noninteger',
'moraine/cli/plot.py'),
'moraine.cli.plot._next_ras': ('CLI/plot.html#_next_ras', 'moraine/cli/plot.py'),
'moraine.cli.plot._pc_downsample_all_and_save': ( 'CLI/plot.html#_pc_downsample_all_and_save',
'moraine/cli/plot.py'),
'moraine.cli.plot._pc_inf_0_post_proc': ('CLI/plot.html#_pc_inf_0_post_proc', 'moraine/cli/plot.py'),
'moraine.cli.plot._pc_inf_all_post_proc': ('CLI/plot.html#_pc_inf_all_post_proc', 'moraine/cli/plot.py'),
'moraine.cli.plot._pc_inf_seq_post_proc': ('CLI/plot.html#_pc_inf_seq_post_proc', 'moraine/cli/plot.py'),
'moraine.cli.plot._ras_downsample': ('CLI/plot.html#_ras_downsample', 'moraine/cli/plot.py'),
'moraine.cli.plot._ras_downsample_all_and_save': ( 'CLI/plot.html#_ras_downsample_all_and_save',
'moraine/cli/plot.py'),
'moraine.cli.plot._ras_inf_0_post_proc': ('CLI/plot.html#_ras_inf_0_post_proc', 'moraine/cli/plot.py'),
'moraine.cli.plot._ras_inf_all_post_proc': ( 'CLI/plot.html#_ras_inf_all_post_proc',
'moraine/cli/plot.py'),
'moraine.cli.plot._ras_inf_seq_post_proc': ( 'CLI/plot.html#_ras_inf_seq_post_proc',
'moraine/cli/plot.py'),
'moraine.cli.plot._zarr_stack_info': ('CLI/plot.html#_zarr_stack_info', 'moraine/cli/plot.py'),
'moraine.cli.plot.pc_plot': ('CLI/plot.html#pc_plot', 'moraine/cli/plot.py'),
'moraine.cli.plot.pc_pyramid': ('CLI/plot.html#pc_pyramid', 'moraine/cli/plot.py'),
'moraine.cli.plot.ras_plot': ('CLI/plot.html#ras_plot', 'moraine/cli/plot.py'),
Expand Down
196 changes: 89 additions & 107 deletions moraine/cli/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,6 @@
import numpy as np
import math
from pathlib import Path
import shutil
import pandas as pd
from tqdm import tqdm
import sys
from functools import partial
from typing import Callable
import numpy as np
Expand All @@ -31,26 +27,9 @@
from ..rtree import HilbertRtree
from .logging import mc_logger
from ..coord_ import Coord
from . import mk_clean_dir, dask_from_zarr, dask_to_zarr, parallel_write_zarr
from . import mk_clean_dir, dask_from_zarr, dask_to_zarr, parallel_write_zarr, parallel_read_zarr

# %% ../../nbs/CLI/plot.ipynb 5
def _zarr_stack_info(
zarr_path_list, #list of zarr path
):
shape_list = []; chunks_list = []; dtype_list = []
for zarr_path in zarr_path_list:
zarr_data = zarr.open(zarr_path,'r')
shape_list.append(zarr_data.shape)
chunks_list.append(zarr_data.chunks)
dtype_list.append(zarr_data.dtype)
df = pd.DataFrame({'path':zarr_path_list,'shape':shape_list,'chunks':chunks_list,'dtype':dtype_list})
return df

# %% ../../nbs/CLI/plot.ipynb 7
def _ras_downsample(ras,down_level=1):
return ras[::2**down_level,::2**down_level]

# %% ../../nbs/CLI/plot.ipynb 8
# %% ../../nbs/CLI/plot.ipynb 6
def _ras_downsample_all_and_save(ras,zarrs,channel_idx):
slices = [slice(None),slice(None)]
if len(channel_idx) != 0:
Expand All @@ -62,7 +41,7 @@ def _ras_downsample_all_and_save(ras,zarrs,channel_idx):
ras_ = ras[::2**level,::2**level]
parallel_write_zarr(ras_,zarrs[level],slices)

# %% ../../nbs/CLI/plot.ipynb 9
# %% ../../nbs/CLI/plot.ipynb 7
@mc_logger
def ras_pyramid(
ras:str, # path to input data, 2D zarr array (one single raster) or 3D zarr array (a stack of rasters)
Expand All @@ -75,6 +54,7 @@ def ras_pyramid(
):
'''render raster data to pyramid of different zoom levels.'''
logger = logging.getLogger(__name__)
logger.info('clean out dir')
out_dir = Path(out_dir); mk_clean_dir(out_dir)

ras_zarr = zarr.open(ras,'r')
Expand Down Expand Up @@ -127,7 +107,7 @@ def ras_pyramid(
logger.info('computing finished.')
logger.info('dask cluster closed.')

# %% ../../nbs/CLI/plot.ipynb 15
# %% ../../nbs/CLI/plot.ipynb 13
# there should be a better way to achieve variable kdims, but I haven't found one.
def _hv_ras_callback_0(x_range,y_range,width,height,scale,data_dir,post_proc,coord,level_increase):
# start = time.time()
Expand Down Expand Up @@ -199,7 +179,7 @@ def _hv_ras_callback_2(x_range,y_range,width,height,scale,data_dir,post_proc,coo
data = post_proc(data_zarr,slice(xi0,xim+1),slice(yi0,yim+1),i,j)
return hv.Image(data[::-1,:],bounds=coord_bbox)

# %% ../../nbs/CLI/plot.ipynb 16
# %% ../../nbs/CLI/plot.ipynb 14
def _default_ras_post_proc(data_zarr, xslice, yslice, *kdims):
data_n_kdim = data_zarr.ndim - 2
assert len(kdims) == data_n_kdim
Expand All @@ -209,7 +189,7 @@ def _default_ras_post_proc(data_zarr, xslice, yslice, *kdims):
else:
return data_zarr[yslice,xslice,kdims]

# %% ../../nbs/CLI/plot.ipynb 17
# %% ../../nbs/CLI/plot.ipynb 15
def _ras_inf_0_post_proc(data_zarr, xslice, yslice, *kdims):
data_n_kdim = data_zarr.ndim - 2
assert len(kdims) == 1
Expand Down Expand Up @@ -257,7 +237,7 @@ def _ras_inf_all_post_proc(data_zarr, xslice, yslice, *kdims):
else:
return data_zarr[yslice,xslice,i,j]

# %% ../../nbs/CLI/plot.ipynb 18
# %% ../../nbs/CLI/plot.ipynb 16
def ras_plot(
pyramid_dir:str, # directory to the rendered ras pyramid
post_proc:Callable=None, # function for the post processing, can be None, 'intf_0', 'intf_seq', 'intf_all' or user-defined function
Expand Down Expand Up @@ -305,7 +285,7 @@ def ras_plot(
post_proc=post_proc,coord=coord,level_increase=level_increase),streams=[rangexy,plotsize],kdims=kdims)
return images

# %% ../../nbs/CLI/plot.ipynb 52
# %% ../../nbs/CLI/plot.ipynb 50
@ngpjit
def _next_level_idx_from_raster_of_integer(pc_idx, nan_value):
'''return the raster indices to the next level of raster'''
Expand All @@ -328,35 +308,26 @@ def _next_level_idx_from_raster_of_integer(pc_idx, nan_value):
xi[i,j] = idx_[0,1] + j*2
return yi, xi

# %% ../../nbs/CLI/plot.ipynb 53
# currently not used
@ngpjit
def _next_level_idx_from_raster_of_noninteger(pc_data):
'''return the raster indices to the next level of raster'''
assert pc_data.ndim == 2
ny, nx = pc_data.shape
next_ny, next_nx = math.ceil(ny/2), math.ceil(nx/2)
xi = np.empty((next_ny,next_nx), dtype=np.int32)
yi = np.empty((next_ny,next_nx), dtype=np.int32)
# %% ../../nbs/CLI/plot.ipynb 52
def _pc_downsample_all_and_save(pc,coord,gix,yis,xis,pc_zarr,ras_zarrs,channel_idx):
pc_slices = [slice(None),]
ras_slices = [slice(None),slice(None)]
if len(channel_idx) != 0:
for idx in channel_idx:
pc_slices.append(slice(idx,idx+1))
ras_slices.append(slice(idx,idx+1))
pc_slices = tuple(pc_slices)
ras_slices = tuple(ras_slices)
parallel_write_zarr(pc,pc_zarr,pc_slices)

for i in range(next_ny):
for j in prange(next_nx):
# Select a 2x2 box from the original array
box = pc_data[i*2:min(i*2+2, ny), j*2:min(j*2+2, nx)]
idx_ = np.argwhere(~np.isnan(box))
if len(idx_) == 0:
yi[i,j]= i*2
xi[i,j] = j*2
else:
yi[i,j] = idx_[0,0] + i*2
xi[i,j] = idx_[0,1] + j*2
return yi, xi
ras = coord.rasterize(pc,gix)
parallel_write_zarr(ras,ras_zarrs[0],ras_slices)

# %% ../../nbs/CLI/plot.ipynb 55
def _next_ras(ras,yi,xi):
return ras[yi,xi]
for level in range(1,len(ras_zarrs)):
ras = ras[yis[level-1],xis[level-1]]
parallel_write_zarr(ras,ras_zarrs[level],ras_slices)

# %% ../../nbs/CLI/plot.ipynb 56
# %% ../../nbs/CLI/plot.ipynb 53
@mc_logger
def pc_pyramid(
pc:str, # path to point cloud data, 1D array (one single pc image) or 2D zarr array (a stack of pc images)
Expand All @@ -374,6 +345,7 @@ def pc_pyramid(
):
'''render point cloud data to pyramid of different zoom levels.'''
logger = logging.getLogger(__name__)
logger.info('clean out dir')
out_dir = Path(out_dir); mk_clean_dir(out_dir)

pc_zarr = zarr.open(pc,'r')
Expand All @@ -383,12 +355,14 @@ def pc_pyramid(
channel_chunks = (1,)*(pc_zarr.ndim-1)
logger.info(f'rendering point cloud data coordinates:')
if x is None and y is None:
yx = zarr.open(yx,'r')[:]
yx_zarr = zarr.open(yx,'r')
assert yx_zarr.shape[1] == 2
yx = parallel_read_zarr(yx_zarr,(slice(None),slice(0,2)))
else:
y_zarr = zarr.open(y,'r')
yx = np.empty((y_zarr.shape[0],2),dtype=y_zarr.dtype)
yx[:,0] = zarr.open(y,'r')[:]
yx[:,1] = zarr.open(x,'r')[:]
yx[:,0] = parallel_read_zarr(zarr.open(y,'r'),(slice(None),))
yx[:,1] = parallel_read_zarr(zarr.open(x,'r'),(slice(None),))
x, y = yx[:,1], yx[:,0]

x0, xm, y0, ym = x.min(), x.max(), y.min(), y.max()
Expand All @@ -402,64 +376,72 @@ def pc_pyramid(
gix = coord.coords2gixs(yx)
maxlevel = coord.maxlevel

out_x_zarr = zarr.open(out_dir/f'x.zarr','w',shape=x.shape,dtype=x.dtype,chunks=(pc_chunks,))
out_y_zarr = zarr.open(out_dir/f'y.zarr','w',shape=y.shape,dtype=y.dtype,chunks=(pc_chunks,))
logger.zarr_info(out_dir/f'x.zarr',out_x_zarr)
logger.zarr_info(out_dir/f'y.zarr',out_y_zarr)
parallel_write_zarr(x, out_x_zarr,(slice(None),))
parallel_write_zarr(y, out_y_zarr,(slice(None),))
del x, y, yx
logger.info('pc data coordinates rendering ends.')

yis = []; xis = []
for level in range(maxlevel+1):
if level == 0:
current_idx = coord.rasterize_iidx(gix)
else:
yi, xi = _next_level_idx_from_raster_of_integer(last_idx,-1)
yis.append(yi); xis.append(xi)
current_idx = last_idx[yi,xi]
out_idx_zarr = zarr.open(out_dir/f'idx_{level}.zarr',shape=current_idx.shape,dtype=current_idx.dtype,chunks=ras_chunks)
logger.zarr_info(out_dir/f'idx_{level}.zarr',out_idx_zarr)
parallel_write_zarr(current_idx,out_idx_zarr,(slice(None),slice(None)))
last_idx = current_idx
logger.info('rasterized idx rendering ends')

with LocalCluster(processes=processes,
n_workers=n_workers,
threads_per_worker=threads_per_worker,
**dask_cluster_arg) as cluster, Client(cluster) as client:
logger.info('dask local cluster started.')
logger.info('dask local cluster started to render pc data.')
logger.dask_cluster_info(cluster)
output_futures = []
x_darr, y_darr = da.from_array(x,chunks=pc_chunks), da.from_array(y,chunks=pc_chunks)
output_futures.append(da.to_zarr(x_darr, out_dir/f'x.zarr', compute=False, overwrite=True))
output_futures.append(da.to_zarr(y_darr, out_dir/f'y.zarr', compute=False, overwrite=True))
logger.info('pc data coordinates rendering ends.')

pc_darr = dask_from_zarr(pc,parallel_dims=0)
pc_darr = pc_darr.rechunk((n_pc,*channel_chunks))
#pc_darr = da.from_zarr(pc,chunks=(n_pc,*channel_chunks),inline_array=True)
#out_pc_darr = pc_darr.rechunk((pc_chunks,*channel_chunks))
#output_futures.append(da.to_zarr(out_pc_darr, out_dir/f'pc.zarr', compute=False, overwrite=True))
output_futures.append(dask_to_zarr(pc_darr, out_dir/f'pc.zarr', chunks=(pc_chunks,*channel_chunks)))
logger.info('pc data rendering ends.')

delayed_next_idx = delayed(_next_level_idx_from_raster_of_integer,pure=True,nout=2)

out_pc_zarr = zarr.open(out_dir/f'pc.zarr','w',shape=pc_zarr.shape, dtype=pc_zarr.dtype, chunks=(pc_chunks,*channel_chunks))
logger.zarr_info(out_dir/f'pc.zarr', out_pc_zarr)

downsampled_ras_zarrs = []
for level in range(maxlevel+1):
if level == 0:
current_ras = pc_darr.map_blocks(coord.rasterize, gix, dtype=pc_darr.dtype, chunks=(ny,nx,*channel_chunks))
current_idx = da.from_array(coord.rasterize_iidx(gix), chunks=(ny,nx))
else:
last_idx_delayed = last_idx.to_delayed()
yi, xi = np.empty((1,1),dtype=object), np.empty((1,1),dtype=object)
yi_, xi_ = delayed_next_idx(last_idx_delayed[0,0],-1)
shape = (math.ceil(ny/(2**level)), math.ceil(nx/(2**level)))
yi_ = da.from_delayed(yi_,shape=shape,meta=np.array((),dtype=np.int32))
xi_ = da.from_delayed(xi_,shape=shape,meta=np.array((),dtype=np.int32))
yi[0,0] = yi_; xi[0,0] = xi_
yi, xi = da.block(yi.tolist()), da.block(xi.tolist())

current_ras = last_ras.map_blocks(_next_ras, yi, xi, dtype=last_ras.dtype, chunks=(*shape, *channel_chunks))
current_idx = last_idx.map_blocks(_next_ras, yi, xi, dtype=last_idx.dtype, chunks=shape)

# out_current_ras = current_ras.rechunk((*ras_chunks, *channel_chunks))
# out_current_idx = current_idx.rechunk(ras_chunks)
logger.darr_info(f'rasterized pc data at level {level}', current_ras)
logger.darr_info(f'rasterized pc index at level {level}', current_idx)
output_futures.append(dask_to_zarr(current_ras, out_dir/f'{level}.zarr', chunks=(*ras_chunks, *channel_chunks)))
output_futures.append(dask_to_zarr(current_idx, out_dir/f'idx_{level}.zarr', chunks=ras_chunks))
#output_futures.append(da.to_zarr(out_current_ras, out_dir/f'{level}.zarr', compute=False, overwrite=True))
#output_futures.append(da.to_zarr(out_current_idx, out_dir/f'idx_{level}.zarr', compute=False, overwrite=True))
last_ras = current_ras
last_idx = current_idx
shape = (math.ceil(ny/(2**level)), math.ceil(nx/(2**level)))
downsampled_ras_store = zarr.NestedDirectoryStore(out_dir/f'{level}.zarr')
downsampled_ras_zarr = zarr.open(
downsampled_ras_store,'w',
shape=(*shape,*pc_zarr.shape[1:]),
dtype=pc_zarr.dtype,
chunks=(*ras_chunks,*channel_chunks),)
logger.zarr_info(out_dir/f'{level}.zarr',downsampled_ras_zarr)
downsampled_ras_zarrs.append(downsampled_ras_zarr)

pc_darr = dask_from_zarr(pc,chunks=(n_pc,*channel_chunks))
pc_delayed = pc_darr.to_delayed().reshape(pc_zarr.shape[1:])
out_delayed = np.empty_like(pc_delayed,dtype=object)
downsample_save_delayed = delayed(_pc_downsample_all_and_save,pure=True,nout=0)

with np.nditer(out_delayed,flags=['multi_index','refs_ok'], op_flags=['readwrite']) as arr_it:
for arr_block in arr_it:
channel_idx = arr_it.multi_index
out_delayed[channel_idx] = downsample_save_delayed(pc_delayed[channel_idx],coord,gix, yis, xis, out_pc_zarr, downsampled_ras_zarrs,channel_idx)
out_delayed[channel_idx] = da.from_delayed(out_delayed[channel_idx],shape=(1,),dtype=int)
out = da.block(out_delayed.tolist())

logger.info('computing graph setted. doing all the computing.')
futures = client.persist(output_futures)
futures = client.persist(out)
progress(futures,notebook=False)
time.sleep(0.1)
da.compute(futures)
logger.info('computing finished.')
logger.info('dask cluster closed.')

# %% ../../nbs/CLI/plot.ipynb 61
# %% ../../nbs/CLI/plot.ipynb 58
def _is_nan_range(x_range):
if x_range is None:
return True
Expand All @@ -469,7 +451,7 @@ def _is_nan_range(x_range):
return True
return False

# %% ../../nbs/CLI/plot.ipynb 62
# %% ../../nbs/CLI/plot.ipynb 59
def _hv_pc_Image_callback_0(x_range,y_range,width,height,scale,data_dir,post_proc_ras,coord,level_increase):
if _is_nan_range(x_range):
x0 = coord.x0; xm = coord.xm
Expand Down Expand Up @@ -553,7 +535,7 @@ def _hv_pc_Image_callback_2(x_range,y_range,width,height,scale,data_dir,post_pro
else:
return hv.Image([],vdims=['z','idx'])

# %% ../../nbs/CLI/plot.ipynb 63
# %% ../../nbs/CLI/plot.ipynb 60
def _hv_pc_Points_callback_0(x_range,y_range,width,height,scale,data_dir,post_proc_pc,coord,rtree,level_increase):
if _is_nan_range(x_range):
x0 = coord.x0; xm = coord.xm
Expand Down Expand Up @@ -632,7 +614,7 @@ def _hv_pc_Points_callback_2(x_range,y_range,width,height,scale,data_dir,post_pr
data = post_proc_pc(data_zarr,idx,i,j)
return hv.Points((x,y,data,idx),vdims=['z','idx'])

# %% ../../nbs/CLI/plot.ipynb 64
# %% ../../nbs/CLI/plot.ipynb 61
def _default_pc_post_proc(data_zarr, idx_array, *kdims):
data_n_kdim = data_zarr.ndim - 1
assert len(kdims) == data_n_kdim
Expand All @@ -641,7 +623,7 @@ def _default_pc_post_proc(data_zarr, idx_array, *kdims):
else:
return data_zarr[idx_array,kdims]

# %% ../../nbs/CLI/plot.ipynb 65
# %% ../../nbs/CLI/plot.ipynb 62
def _pc_inf_0_post_proc(data_zarr, idx_array, *kdims):
data_n_kdim = data_zarr.ndim - 1
assert len(kdims) == 1
Expand Down Expand Up @@ -690,7 +672,7 @@ def _pc_inf_all_post_proc(data_zarr, idx_array, *kdims):
else:
return data_zarr[idx_array,i,j]

# %% ../../nbs/CLI/plot.ipynb 66
# %% ../../nbs/CLI/plot.ipynb 63
def pc_plot(
pyramid_dir:str, # directory to the rendered point cloud pyramid
post_proc_ras:Callable=None, # function for the post processing
Expand Down
Loading

0 comments on commit e151d4b

Please sign in to comment.