Skip to content

Commit

Permalink
Cube build stage3 - JP-2096 (spacetelescope#6093)
Browse files Browse the repository at this point in the history
* updates using numba jit

* flake 8 fixes

* a few numba updates

* updates to support internal_cal and numba

* fix test - removing unused resolution file

* improved blotting speed using numba

* added c code for emsm

* fixed setup.py  to compile match_det_cube

* updates to c code

* some changes to c python interface

* more fixes to c code

* added cube_match_internal and pulled common c routines to cube_utils.c

* Clean up

* remove cube_cloud.py

* added weighting=msm as possibility for c extension cube weighting

* removed declaration of numba from routine

* fix typo

* flake8 fix

* remove printf from c code

* remove print in ifu_cube.py

* typo in cube_match_sky.c

* changes after review

* fix alloc arrays def

* Updated change log

* remove print statement
  • Loading branch information
jemorrison authored and loicalbert committed Nov 5, 2021
1 parent dc266b1 commit d38a0a7
Show file tree
Hide file tree
Showing 16 changed files with 2,661 additions and 2,116 deletions.
2 changes: 2 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ cube_build
----------

- Fix bug when creating cubes using output_type=channel. [#6138]
- Move computationally intensive routines to c extensions and
removed miri psf weight function. [#6093]

datamodels
----------
Expand Down
50 changes: 50 additions & 0 deletions jwst/cube_build/blot_cube.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
""" Map the detector pixels to the cube coordinate system.
This is where the weight functions are used.
"""
import numpy as np
import logging
log = logging.getLogger(__name__)
log.setLevel(logging.DEBUG)


def blot_overlap(ipt, xstart,
xcenter, ycenter,
x_cube, y_cube,
flux_cube,
blot_xsize,
blot_flux, blot_weight):

""" Blot the median sky image back to the detector
ipt is the median element.
xcenter, ycenter are the detector pixel arrays
xstart is only valid for MIRI and is the start of the x detector value for channel
0 for channels on left and ~512 for channels on right
x_cube, y_cube: median cube IFU mapped backwards to detector
flux_cube: median flux
blot_flux & blot_weight: blotted values of flux_cube to detector
"""

roi_det = 1.0 # Just large enough that we don't get holes
xdistance = np.absolute(x_cube[ipt] - xcenter)
ydistance = np.absolute(y_cube[ipt] - ycenter)

index_x = np.where(xdistance <= roi_det)
index_y = np.where(ydistance <= roi_det)

if len(index_x[0]) > 0 and len(index_y[0]) > 0:

d1pix = x_cube[ipt] - xcenter[index_x]
d2pix = y_cube[ipt] - ycenter[index_y]

dxy = [(dx * dx + dy * dy) for dy in d2pix for dx in d1pix]
dxy = np.array(dxy)
dxy = np.sqrt(dxy)
weight_distance = np.exp(-dxy)
weighted_flux = weight_distance * flux_cube[ipt]

index2d = [iy * blot_xsize + ix for iy in index_y[0] for ix in (index_x[0] + xstart)]
index2d = np.array(index2d)

blot_flux[index2d] = blot_flux[index2d] + weighted_flux
blot_weight[index2d] = blot_weight[index2d] + weight_distance
74 changes: 23 additions & 51 deletions jwst/cube_build/blot_cube_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from gwcs import wcstools
from . import instrument_defaults

from . import blot_cube

log = logging.getLogger(__name__)
log.setLevel(logging.DEBUG)

Expand Down Expand Up @@ -199,16 +201,15 @@ def blot_images_miri(self):
# pixel.
# the Regular grid is on the x,y detector

blot_flux = self.blot_overlap_quick(model, xcenter, ycenter,
xstart, x_cube, y_cube,
flux_cube)
blot_flux = self.blot_overlap_miri(model, xcenter, ycenter,
xstart, x_cube, y_cube,
flux_cube)
blot.data = blot_flux
blot_models.append(blot)
return blot_models
# ________________________________________________________________________________

def blot_overlap_quick(self, model, xcenter, ycenter, xstart,
x_cube, y_cube, flux_cube):
def blot_overlap_miri(self, model, xcenter, ycenter, xstart,
x_cube, y_cube, flux_cube):

# blot_overlap_quick finds to overlap between the blotted sky values
# (x_cube, y_cube) and the detector pixels.
Expand All @@ -223,31 +224,16 @@ def blot_overlap_quick(self, model, xcenter, ycenter, xstart,
blot_weight = np.ndarray.flatten(blot_weight)

# loop over valid points in cube (not empty edge pixels)
roi_det = 1.0 # Just large enough that we don't get holes
ivalid = np.nonzero(np.absolute(flux_cube))

ivalid = np.nonzero(np.absolute(flux_cube))
t0 = time.time()
for ipt in ivalid[0]:
# search xcenter and ycenter seperately. These arrays are smallsh.
# xcenter size = naxis1 on detector (for MIRI only 1/2 array)
# ycenter size = naxis2 on detector
xdistance = np.absolute(x_cube[ipt] - xcenter)
ydistance = np.absolute(y_cube[ipt] - ycenter)

index_x = np.where(xdistance <= roi_det)
index_y = np.where(ydistance <= roi_det)

if len(index_x[0]) > 0 and len(index_y[0]) > 0:
d1pix = np.array(x_cube[ipt] - xcenter[index_x])
d2pix = np.array(y_cube[ipt] - ycenter[index_y])

dxy = [(dx * dx + dy * dy) for dy in d2pix for dx in d1pix]
dxy = np.sqrt(dxy)
weight_distance = np.exp(-dxy)
weighted_flux = weight_distance * flux_cube[ipt]
index2d = [iy * blot_xsize + ix for iy in index_y[0] for ix in (index_x[0] + xstart)]
blot_flux[index2d] = blot_flux[index2d] + weighted_flux
blot_weight[index2d] = blot_weight[index2d] + weight_distance
blot_cube.blot_overlap(ipt, xstart,
xcenter, ycenter,
x_cube, y_cube,
flux_cube,
blot_xsize,
blot_flux, blot_weight)

t1 = time.time()
log.debug(f"Time to blot median image to input model = {t1-t0:.1f}")
Expand All @@ -257,7 +243,6 @@ def blot_overlap_quick(self, model, xcenter, ycenter, xstart,
blot_flux = blot_flux.reshape((blot_ysize, blot_xsize))
return blot_flux

# ************************************************************************
def blot_images_nirspec(self):
""" Core blotting routine for NIRSPEC
Expand Down Expand Up @@ -299,7 +284,7 @@ def blot_images_nirspec(self):
nslices = 30
log.info('Looping over 30 slices on NIRSPEC detector, this takes a little while')
t0 = time.time()
roi_det = 1.0 # Just large enough that we don't get holes

for ii in range(nslices):
ts0 = time.time()
# for each slice pull out the blotted values that actually fall on the slice region
Expand Down Expand Up @@ -338,30 +323,17 @@ def blot_images_nirspec(self):
y_slice = y_slice[fuse]
flux_slice = flux_slice[fuse]

xstart = 0
nn = flux_slice.size
for ipt in range(nn):
# search xcenter and ycenter seperately. These arrays are smallish.
# xcenter size = naxis1 on detector
# ycenter size = naxis2 on detector
xdistance = np.absolute(x_slice[ipt] - xcenter)
ydistance = np.absolute(y_slice[ipt] - ycenter)

index_x = np.where(xdistance <= roi_det)
index_y = np.where(ydistance <= roi_det)

if len(index_x[0]) > 0 and len(index_y[0]) > 0:
d1pix = np.array(x_slice[ipt] - xcenter[index_x])
d2pix = np.array(y_slice[ipt] - ycenter[index_y])

dxy = [(dx * dx + dy * dy) for dy in d2pix for dx in d1pix]
dxy = np.sqrt(dxy)
weight_distance = np.exp(-dxy)
weighted_flux = weight_distance * flux_slice[ipt]
index2d = [iy * blot_xsize + ix for iy in index_y[0] for ix in (index_x[0])]
blot_flux[index2d] = blot_flux[index2d] + weighted_flux
blot_weight[index2d] = blot_weight[index2d] + weight_distance
blot_cube.blot_overlap(ipt, xstart,
xcenter, ycenter,
x_slice, y_slice,
flux_slice,
blot_xsize,
blot_flux, blot_weight)
ts1 = time.time()
log.debug(f"Time to map 1 slice = {ts1-ts0:.1f}")
log.debug(f"Time to blot 1 slice on NIRspec = {ts1-ts0:.1f}")
# done mapping median cube to this input model
t1 = time.time()
log.debug(f"Time to blot median image to input model = {t1-t0:.1f}")
Expand Down
14 changes: 0 additions & 14 deletions jwst/cube_build/cube_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ def __init__(self,
input_models,
input_filenames,
par_filename,
resol_filename,
**pars):
""" Initialize the high level of information for the ifu cube
Expand All @@ -36,15 +35,12 @@ def __init__(self,
list of fits filenames
par_filename: str
cube parameter reference filename
resol_filename: str
miri resolution reference filename
pars : dictionary holding top level cube parameters
"""

self.input_models = input_models
self.input_filenames = input_filenames
self.par_filename = par_filename
self.resol_filename = resol_filename
self.single = pars.get('single')
self.channel = pars.get('channel')
self.subchannel = pars.get('subchannel')
Expand Down Expand Up @@ -73,7 +69,6 @@ def setup(self):
Read in necessary reference data:
* cube parameter reference file
* if miripsf weighting parameter is set then read in resolution file
This routine fills in the instrument_info dictionary, which holds the
default spatial and spectral size of the output cube, as well as,
Expand Down Expand Up @@ -117,15 +112,6 @@ def setup(self):
self.all_filter,
instrument_info)
# -------------------------------------------------------------------------------
# Read the miri resolution reference file
if self.weighting == 'miripsf':
log.info('Reading default MIRI cube resolution file %s',
self.resol_filename)
cube_build_io_util.read_resolution_file(self.resol_filename,
self.all_channel,
self.all_subchannel,
instrument_info)
# _______________________________________________________________________________
self.instrument_info = instrument_info
# _______________________________________________________________________________
# Set up values to return and acess for other parts of cube_build
Expand Down
17 changes: 3 additions & 14 deletions jwst/cube_build/cube_build_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class CubeBuildStep (Step):
scale1 = float(default=0.0) # cube sample size to use for axis 1, arc seconds
scale2 = float(default=0.0) # cube sample size to use for axis 2, arc seconds
scalew = float(default=0.0) # cube sample size to use for axis 3, microns
weighting = option('emsm','msm','miripsf',default = 'emsm') # Type of weighting function
weighting = option('emsm','msm',default = 'emsm') # Type of weighting function
coord_system = option('skyalign','world','internal_cal','ifualign',default='skyalign') # Output Coordinate system.
rois = float(default=0.0) # region of interest spatial size, arc seconds
roiw = float(default=0.0) # region of interest wavelength size, microns
Expand All @@ -57,7 +57,7 @@ class CubeBuildStep (Step):
output_use_model = boolean(default=true) # Use filenames in the output models
"""

reference_file_types = ['cubepar', 'resol']
reference_file_types = ['cubepar']

# ________________________________________________________________________________
def process(self, input):
Expand Down Expand Up @@ -140,7 +140,7 @@ def process(self, input):
# if interpolation is point cloud then weighting can be
# 1. MSM: modified shepard method
# 2. EMSM
# 3. miripsf - weighting for MIRI based on PSF and LSF

if self.coord_system == 'skyalign':
self.interpolation = 'pointcloud'

Expand Down Expand Up @@ -221,16 +221,6 @@ def process(self, input):
self.log.warning('No default cube parameters reference file found')
return
# ________________________________________________________________________________
# If miripsf weight is set then set up reference file
resol_filename = None
if self.weighting == 'miripsf':
resol_filename = self.get_reference_file(self.input_models[0], 'resol')
self.log.info(f'MIRI resol reference file {resol_filename}')
if resol_filename == 'N/A':
self.log.warning('No spectral resolution reference file found')
self.log.warning('Run again and turn off miripsf')
return
# ________________________________________________________________________________
# shove the input parameters in to pars to pull out in general cube_build.py

pars = {
Expand Down Expand Up @@ -269,7 +259,6 @@ def process(self, input):
self.input_models,
self.input_filenames,
par_filename,
resol_filename,
**pars)
# ________________________________________________________________________________
# cubeinfo.setup:
Expand Down
26 changes: 9 additions & 17 deletions jwst/cube_build/cube_build_wcs_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,27 +111,19 @@ def find_corners_MIRI(input, this_channel, instrument_info, coord_system):
lambda_min = np.nanmin(lam)
lambda_max = np.nanmax(lam)

# before returning, ra should be between 0 to 360
if a_min < 0:
a_min = a_min + 360
if a_max >= 360.0:
a_max = a_max - 360.0

if a1 < 0:
a1 = a1 + 360
if a1 > 360.0:
a1 = a1 - 360.0

if a2 < 0:
a2 = a2 + 360
if a2 > 360.0:
a2 = a2 - 360.0
if coord_system != 'internal_cal':
# before returning, ra should be between 0 to 360
a_min %= 360
a_max %= 360

a1 %= 360
a2 %= 360

return a_min, b1, a_max, b2, a1, b_min, a2, b_max, lambda_min, lambda_max
# *****************************************************************************


def find_corners_NIRSPEC(input, this_channel, instrument_info, coord_system):
def find_corners_NIRSPEC(input, instrument_info, coord_system):
"""Find the sky footprint of a slice of a NIRSpec exposure
For each slice find:
Expand Down Expand Up @@ -241,5 +233,5 @@ def find_corners_NIRSPEC(input, this_channel, instrument_info, coord_system):

lambda_min = min(lambda_slice)
lambda_max = max(lambda_slice)

return a_min, b1, a_max, b2, a1, b_min, a2, b_max, lambda_min, lambda_max
# ______________________________________________________________________________
Loading

0 comments on commit d38a0a7

Please sign in to comment.