Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JP-2096: Cube build stage3 #6093

Merged
merged 25 commits into from
Jul 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
484bfa2
updates using numba jit
jemorrison Jun 1, 2021
b37940e
flake 8 fixes
jemorrison Jun 2, 2021
da1ae3c
a few numba updates
jemorrison Jun 7, 2021
9e230c7
updates to support internal_cal and numba
jemorrison Jun 7, 2021
d1e4ddb
fix test - removing unused resolution file
jemorrison Jun 8, 2021
44e097f
improved blotting speed using numba
jemorrison Jun 11, 2021
3c60bb4
added c code for emsm
jemorrison Jun 25, 2021
da720e8
fixed setup.py to compile match_det_cube
jemorrison Jun 25, 2021
8fb11b3
updates to c code
jemorrison Jul 8, 2021
d45ffa2
some changes to c python interface
jemorrison Jul 8, 2021
3686ecb
more fixes to c code
jemorrison Jul 9, 2021
36f0ef2
added cube_match_internal and pulled common c routines to cube_utils.c
jemorrison Jul 17, 2021
e6bbc2b
Clean up
jemorrison Jul 19, 2021
fabb5e6
remove cube_cloud.py
jemorrison Jul 19, 2021
e8e734a
added weighting=msm as possibility for c extension cube weighting
jemorrison Jul 19, 2021
ec1b3d4
removed declaration of numba from routine
jemorrison Jul 19, 2021
3764bb5
fix typo
jemorrison Jul 19, 2021
17cd929
flake8 fix
jemorrison Jul 19, 2021
80d21e8
remove printf from c code
jemorrison Jul 19, 2021
b01ec6e
remove print in ifu_cube.py
jemorrison Jul 19, 2021
ba41124
typo in cube_match_sky.c
jemorrison Jul 19, 2021
de3f107
changes after review
jemorrison Jul 20, 2021
e506588
fix alloc arrays def
jemorrison Jul 20, 2021
0de907f
Updated change log
jemorrison Jul 21, 2021
b173565
remove print statement
jemorrison Jul 21, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ cube_build
----------

- Fix bug when creating cubes using output_type=channel. [#6138]
- Move computationally intensive routines to c extensions and
removed miri psf weight function. [#6093]

datamodels
----------
Expand Down
50 changes: 50 additions & 0 deletions jwst/cube_build/blot_cube.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
""" Map the detector pixels to the cube coordinate system.
This is where the weight functions are used.
"""
import numpy as np
import logging
log = logging.getLogger(__name__)
log.setLevel(logging.DEBUG)


def blot_overlap(ipt, xstart,
xcenter, ycenter,
x_cube, y_cube,
flux_cube,
blot_xsize,
blot_flux, blot_weight):

""" Blot the median sky image back to the detector

ipt is the median element.
xcenter, ycenter are the detector pixel arrays
xstart is only valid for MIRI and is the start of the x detector value for channel
0 for channels on left and ~512 for channels on right
x_cube, y_cube: median cube IFU mapped backwards to detector
flux_cube: median flux
blot_flux & blot_weight: blotted values of flux_cube to detector
"""

roi_det = 1.0 # Just large enough that we don't get holes
xdistance = np.absolute(x_cube[ipt] - xcenter)
ydistance = np.absolute(y_cube[ipt] - ycenter)

index_x = np.where(xdistance <= roi_det)
index_y = np.where(ydistance <= roi_det)
Comment on lines +32 to +33
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not familiar with the algorithm, but it seems like index_x and index_y, in principle, could have different lengths or point to different "pixels". Would, this be an issue? Especially different lengths in the code just below.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am going to hold off making changes to blot_cube.py because I have another JP ticket to work just on blot cube after I get the c extensions in this PR committed. I will come back to these changes suggestions later this week.


if len(index_x[0]) > 0 and len(index_y[0]) > 0:

d1pix = x_cube[ipt] - xcenter[index_x]
d2pix = y_cube[ipt] - ycenter[index_y]

dxy = [(dx * dx + dy * dy) for dy in d2pix for dx in d1pix]
dxy = np.array(dxy)
dxy = np.sqrt(dxy)
weight_distance = np.exp(-dxy)
Comment on lines +36 to +43
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
d1pix = x_cube[ipt] - xcenter[index_x]
d2pix = y_cube[ipt] - ycenter[index_y]
dxy = [(dx * dx + dy * dy) for dy in d2pix for dx in d1pix]
dxy = np.array(dxy)
dxy = np.sqrt(dxy)
weight_distance = np.exp(-dxy)
weight_distance = np.exp(-np.sqrt(np.add.outer(
np.square(x_cube[ipt] - xcenter[index_x]),
np.square(y_cube[ipt] - ycenter[index_y])
).ravel()))

or, alternatively:

Suggested change
d1pix = x_cube[ipt] - xcenter[index_x]
d2pix = y_cube[ipt] - ycenter[index_y]
dxy = [(dx * dx + dy * dy) for dy in d2pix for dx in d1pix]
dxy = np.array(dxy)
dxy = np.sqrt(dxy)
weight_distance = np.exp(-dxy)
weight_distance = np.exp(-np.linalg.norm(np.meshgrid(
y_cube[ipt] - ycenter[index_y],
x_cube[ipt] - xcenter[index_x]
), axis=0).ravel())

weighted_flux = weight_distance * flux_cube[ipt]

index2d = [iy * blot_xsize + ix for iy in index_y[0] for ix in (index_x[0] + xstart)]
index2d = np.array(index2d)
Comment on lines +46 to +47
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
index2d = [iy * blot_xsize + ix for iy in index_y[0] for ix in (index_x[0] + xstart)]
index2d = np.array(index2d)
index2d = np.add.outer(index_y[0] * blot_xsize, index_x[0] + xstart).ravel()


blot_flux[index2d] = blot_flux[index2d] + weighted_flux
blot_weight[index2d] = blot_weight[index2d] + weight_distance
74 changes: 23 additions & 51 deletions jwst/cube_build/blot_cube_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from gwcs import wcstools
from . import instrument_defaults

from . import blot_cube

log = logging.getLogger(__name__)
log.setLevel(logging.DEBUG)

Expand Down Expand Up @@ -199,16 +201,15 @@ def blot_images_miri(self):
# pixel.
# the Regular grid is on the x,y detector

blot_flux = self.blot_overlap_quick(model, xcenter, ycenter,
xstart, x_cube, y_cube,
flux_cube)
blot_flux = self.blot_overlap_miri(model, xcenter, ycenter,
xstart, x_cube, y_cube,
flux_cube)
blot.data = blot_flux
blot_models.append(blot)
return blot_models
# ________________________________________________________________________________

def blot_overlap_quick(self, model, xcenter, ycenter, xstart,
x_cube, y_cube, flux_cube):
def blot_overlap_miri(self, model, xcenter, ycenter, xstart,
x_cube, y_cube, flux_cube):

# blot_overlap_quick finds to overlap between the blotted sky values
# (x_cube, y_cube) and the detector pixels.
Expand All @@ -223,31 +224,16 @@ def blot_overlap_quick(self, model, xcenter, ycenter, xstart,
blot_weight = np.ndarray.flatten(blot_weight)

# loop over valid points in cube (not empty edge pixels)
roi_det = 1.0 # Just large enough that we don't get holes
ivalid = np.nonzero(np.absolute(flux_cube))

ivalid = np.nonzero(np.absolute(flux_cube))
t0 = time.time()
for ipt in ivalid[0]:
# search xcenter and ycenter seperately. These arrays are smallsh.
# xcenter size = naxis1 on detector (for MIRI only 1/2 array)
# ycenter size = naxis2 on detector
xdistance = np.absolute(x_cube[ipt] - xcenter)
ydistance = np.absolute(y_cube[ipt] - ycenter)

index_x = np.where(xdistance <= roi_det)
index_y = np.where(ydistance <= roi_det)

if len(index_x[0]) > 0 and len(index_y[0]) > 0:
d1pix = np.array(x_cube[ipt] - xcenter[index_x])
d2pix = np.array(y_cube[ipt] - ycenter[index_y])

dxy = [(dx * dx + dy * dy) for dy in d2pix for dx in d1pix]
dxy = np.sqrt(dxy)
weight_distance = np.exp(-dxy)
weighted_flux = weight_distance * flux_cube[ipt]
index2d = [iy * blot_xsize + ix for iy in index_y[0] for ix in (index_x[0] + xstart)]
blot_flux[index2d] = blot_flux[index2d] + weighted_flux
blot_weight[index2d] = blot_weight[index2d] + weight_distance
blot_cube.blot_overlap(ipt, xstart,
xcenter, ycenter,
x_cube, y_cube,
flux_cube,
blot_xsize,
blot_flux, blot_weight)

t1 = time.time()
log.debug(f"Time to blot median image to input model = {t1-t0:.1f}")
Expand All @@ -257,7 +243,6 @@ def blot_overlap_quick(self, model, xcenter, ycenter, xstart,
blot_flux = blot_flux.reshape((blot_ysize, blot_xsize))
return blot_flux

# ************************************************************************
def blot_images_nirspec(self):
""" Core blotting routine for NIRSPEC

Expand Down Expand Up @@ -299,7 +284,7 @@ def blot_images_nirspec(self):
nslices = 30
log.info('Looping over 30 slices on NIRSPEC detector, this takes a little while')
t0 = time.time()
roi_det = 1.0 # Just large enough that we don't get holes

for ii in range(nslices):
ts0 = time.time()
# for each slice pull out the blotted values that actually fall on the slice region
Expand Down Expand Up @@ -338,30 +323,17 @@ def blot_images_nirspec(self):
y_slice = y_slice[fuse]
flux_slice = flux_slice[fuse]

xstart = 0
nn = flux_slice.size
for ipt in range(nn):
# search xcenter and ycenter seperately. These arrays are smallish.
# xcenter size = naxis1 on detector
# ycenter size = naxis2 on detector
xdistance = np.absolute(x_slice[ipt] - xcenter)
ydistance = np.absolute(y_slice[ipt] - ycenter)

index_x = np.where(xdistance <= roi_det)
index_y = np.where(ydistance <= roi_det)

if len(index_x[0]) > 0 and len(index_y[0]) > 0:
d1pix = np.array(x_slice[ipt] - xcenter[index_x])
d2pix = np.array(y_slice[ipt] - ycenter[index_y])

dxy = [(dx * dx + dy * dy) for dy in d2pix for dx in d1pix]
dxy = np.sqrt(dxy)
weight_distance = np.exp(-dxy)
weighted_flux = weight_distance * flux_slice[ipt]
index2d = [iy * blot_xsize + ix for iy in index_y[0] for ix in (index_x[0])]
blot_flux[index2d] = blot_flux[index2d] + weighted_flux
blot_weight[index2d] = blot_weight[index2d] + weight_distance
blot_cube.blot_overlap(ipt, xstart,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe just set xstart to 0?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

holding off on blot changes - I have opened a separate ticket on blotting

xcenter, ycenter,
x_slice, y_slice,
flux_slice,
blot_xsize,
blot_flux, blot_weight)
ts1 = time.time()
log.debug(f"Time to map 1 slice = {ts1-ts0:.1f}")
log.debug(f"Time to blot 1 slice on NIRspec = {ts1-ts0:.1f}")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is timing for one slice relevant even for debugging purposes?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is for NIRSPEC. It can take several seconds per slice. Once we get it faster I will remove the debug timing

# done mapping median cube to this input model
t1 = time.time()
log.debug(f"Time to blot median image to input model = {t1-t0:.1f}")
Expand Down
14 changes: 0 additions & 14 deletions jwst/cube_build/cube_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ def __init__(self,
input_models,
input_filenames,
par_filename,
resol_filename,
**pars):
""" Initialize the high level of information for the ifu cube

Expand All @@ -36,15 +35,12 @@ def __init__(self,
list of fits filenames
par_filename: str
cube parameter reference filename
resol_filename: str
miri resolution reference filename
pars : dictionary holding top level cube parameters
"""

self.input_models = input_models
self.input_filenames = input_filenames
self.par_filename = par_filename
self.resol_filename = resol_filename
self.single = pars.get('single')
self.channel = pars.get('channel')
self.subchannel = pars.get('subchannel')
Expand Down Expand Up @@ -73,7 +69,6 @@ def setup(self):

Read in necessary reference data:
* cube parameter reference file
* if miripsf weighting parameter is set then read in resolution file

This routine fills in the instrument_info dictionary, which holds the
default spatial and spectral size of the output cube, as well as,
Expand Down Expand Up @@ -117,15 +112,6 @@ def setup(self):
self.all_filter,
instrument_info)
# -------------------------------------------------------------------------------
# Read the miri resolution reference file
if self.weighting == 'miripsf':
log.info('Reading default MIRI cube resolution file %s',
self.resol_filename)
cube_build_io_util.read_resolution_file(self.resol_filename,
self.all_channel,
self.all_subchannel,
instrument_info)
# _______________________________________________________________________________
self.instrument_info = instrument_info
# _______________________________________________________________________________
# Set up values to return and acess for other parts of cube_build
Expand Down
17 changes: 3 additions & 14 deletions jwst/cube_build/cube_build_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class CubeBuildStep (Step):
scale1 = float(default=0.0) # cube sample size to use for axis 1, arc seconds
scale2 = float(default=0.0) # cube sample size to use for axis 2, arc seconds
scalew = float(default=0.0) # cube sample size to use for axis 3, microns
weighting = option('emsm','msm','miripsf',default = 'emsm') # Type of weighting function
weighting = option('emsm','msm',default = 'emsm') # Type of weighting function
coord_system = option('skyalign','world','internal_cal','ifualign',default='skyalign') # Output Coordinate system.
rois = float(default=0.0) # region of interest spatial size, arc seconds
roiw = float(default=0.0) # region of interest wavelength size, microns
Expand All @@ -57,7 +57,7 @@ class CubeBuildStep (Step):
output_use_model = boolean(default=true) # Use filenames in the output models
"""

reference_file_types = ['cubepar', 'resol']
reference_file_types = ['cubepar']

# ________________________________________________________________________________
def process(self, input):
Expand Down Expand Up @@ -140,7 +140,7 @@ def process(self, input):
# if interpolation is point cloud then weighting can be
# 1. MSM: modified shepard method
# 2. EMSM
# 3. miripsf - weighting for MIRI based on PSF and LSF

if self.coord_system == 'skyalign':
self.interpolation = 'pointcloud'

Expand Down Expand Up @@ -221,16 +221,6 @@ def process(self, input):
self.log.warning('No default cube parameters reference file found')
return
# ________________________________________________________________________________
# If miripsf weight is set then set up reference file
resol_filename = None
if self.weighting == 'miripsf':
resol_filename = self.get_reference_file(self.input_models[0], 'resol')
self.log.info(f'MIRI resol reference file {resol_filename}')
if resol_filename == 'N/A':
self.log.warning('No spectral resolution reference file found')
self.log.warning('Run again and turn off miripsf')
return
# ________________________________________________________________________________
# shove the input parameters in to pars to pull out in general cube_build.py

pars = {
Expand Down Expand Up @@ -269,7 +259,6 @@ def process(self, input):
self.input_models,
self.input_filenames,
par_filename,
resol_filename,
**pars)
# ________________________________________________________________________________
# cubeinfo.setup:
Expand Down
26 changes: 9 additions & 17 deletions jwst/cube_build/cube_build_wcs_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,27 +111,19 @@ def find_corners_MIRI(input, this_channel, instrument_info, coord_system):
lambda_min = np.nanmin(lam)
lambda_max = np.nanmax(lam)

# before returning, ra should be between 0 to 360
if a_min < 0:
a_min = a_min + 360
if a_max >= 360.0:
a_max = a_max - 360.0

if a1 < 0:
a1 = a1 + 360
if a1 > 360.0:
a1 = a1 - 360.0

if a2 < 0:
a2 = a2 + 360
if a2 > 360.0:
a2 = a2 - 360.0
if coord_system != 'internal_cal':
# before returning, ra should be between 0 to 360
a_min %= 360
a_max %= 360

a1 %= 360
a2 %= 360

return a_min, b1, a_max, b2, a1, b_min, a2, b_max, lambda_min, lambda_max
# *****************************************************************************


def find_corners_NIRSPEC(input, this_channel, instrument_info, coord_system):
def find_corners_NIRSPEC(input, instrument_info, coord_system):
"""Find the sky footprint of a slice of a NIRSpec exposure

For each slice find:
Expand Down Expand Up @@ -241,5 +233,5 @@ def find_corners_NIRSPEC(input, this_channel, instrument_info, coord_system):

lambda_min = min(lambda_slice)
lambda_max = max(lambda_slice)

return a_min, b1, a_max, b2, a1, b_min, a2, b_max, lambda_min, lambda_max
# ______________________________________________________________________________
Loading