Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed up NRS IFU footprint computation #5969

Merged
merged 6 commits into from
Apr 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ assign_wcs
validate_open_slits function, so a proper error message is provided to
the user [#5939]

- Added computed ``spectral_region`` to ``model.meta.wcsinfo``. [#5969]

associations
------------

Expand Down
63 changes: 45 additions & 18 deletions jwst/assign_wcs/nirspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -1554,37 +1554,29 @@ def gwa_to_ymsa(msa2gwa_model, lam_cen=None, slit=None, slit_y_range=None):
return tab


def nrs_wcs_set_input(input_model, slit_name, wavelength_range=None):
def _nrs_wcs_set_input(input_model, slit_name):
"""
Returns a WCS object for a specific slit, slice or shutter.
Does not compute the bounding box.

Parameters
----------
input_model : `~jwst.datamodels.DataModel`
A WCS object for the all open slitlets in an observation.
slit_name : int or str
Slit.name of an open slit.
wavelength_range: list
Wavelength range for the combination of fliter and grating.

Returns
-------
wcsobj : `~gwcs.wcs.WCS`
WCS object for this slit.
"""
import copy # TODO: Add a copy method to gwcs.WCS
import copy
wcsobj = input_model.meta.wcs
if wavelength_range is None:
_, wrange = spectral_order_wrange_from_model(input_model)
else:
wrange = wavelength_range

slit_wcs = copy.deepcopy(wcsobj)
slit_wcs.set_transform('sca', 'gwa', wcsobj.pipeline[1].transform[1:])
# get the open slits from the model
# Need them to get the slit ymin,ymax
g2s = wcsobj.pipeline[2].transform
open_slits = g2s.slits

g2s = slit_wcs.pipeline[2].transform
slit_wcs.set_transform('gwa', 'slit_frame', g2s.get_model(slit_name))

exp_type = input_model.meta.exposure.type
Expand All @@ -1595,14 +1587,49 @@ def nrs_wcs_set_input(input_model, slit_name, wavelength_range=None):
else:
slit_wcs.set_transform('slit_frame', 'msa_frame',
wcsobj.pipeline[3].transform.get_model(slit_name) & Identity(1))
slit2detector = slit_wcs.get_transform('slit_frame', 'detector')
return slit_wcs


def nrs_wcs_set_input(input_model, slit_name, wavelength_range=None,
slit_y_low=None, slit_y_high=None):
"""
Returns a WCS object for a specific slit, slice or shutter.

Parameters
----------
input_model : `~jwst.datamodels.DataModel`
A WCS object for the all open slitlets in an observation.
slit_name : int or str
Slit.name of an open slit.
wavelength_range: list
Wavelength range for the combination of filter and grating.

Returns
-------
wcsobj : `~gwcs.wcs.WCS`
WCS object for this slit.
"""
def _get_y_range(input_model):
# get the open slits from the model
# Need them to get the slit ymin,ymax
g2s = input_model.meta.wcs.get_transform('gwa', 'slit_frame')
open_slits = g2s.slits
slit = [s for s in open_slits if s.name == slit_name][0]
return slit.ymin, slit.ymax

if wavelength_range is None:
_, wavelength_range = spectral_order_wrange_from_model(input_model)

slit_wcs = _nrs_wcs_set_input(input_model, slit_name)
slit2detector = slit_wcs.get_transform('slit_frame', 'detector')
is_nirspec_ifu = is_nrs_ifu_lamp(input_model) or input_model.meta.exposure.type.lower() == 'nrs_ifu'
if is_nirspec_ifu:
bb = compute_bounding_box(slit2detector, wrange)
bb = compute_bounding_box(slit2detector, wavelength_range)
else:
slit = [s for s in open_slits if s.name == slit_name][0]
bb = compute_bounding_box(slit2detector, wrange,
slit_ymin=slit.ymin, slit_ymax=slit.ymax)
if slit_y_low is None or slit_y_high is None:
slit_y_low, slit_y_high = _get_y_range(input_model)
bb = compute_bounding_box(slit2detector, wavelength_range,
slit_ymin=slit_y_low, slit_ymax=slit_y_high)

slit_wcs.bounding_box = bb
return slit_wcs
Expand Down
89 changes: 89 additions & 0 deletions jwst/assign_wcs/tests/test_nirspec.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Test functions for NIRSPEC WCS - all modes.
"""
import functools
from math import cos, sin
import os.path

Expand Down Expand Up @@ -857,3 +858,91 @@ def test_functional_ifu_prism():
v3 /= 3600
assert_allclose(v2, ins_tab['xV2V3'])
assert_allclose(v3, ins_tab['yV2V3'])


def test_ifu_bbox():
bbox = {0: ((122.0908542999878, 1586.2584665188083),
(773.5411133037417, 825.1150258966278)),
1: ((140.3793485788431, 1606.8904629423566),
(1190.353197027459, 1243.0853605832503)),
2: ((120.0139534379125, 1583.9271768905855),
(724.3249534782219, 775.8104288584977)),
3: ((142.50252648927454, 1609.3106221382388),
(1239.4122720740888, 1292.288713688988)),
4: ((117.88884113088403, 1581.5517394150106),
(674.9787657901347, 726.3752061973377)),
5: ((144.57465414462143, 1611.688447569682),
(1288.4808318659427, 1341.5035313084197)),
6: ((115.8602297714846, 1579.27471654949),
(625.7982466386104, 677.1147840452901)),
7: ((146.7944728147906, 1614.2161842198498),
(1337.531525654835, 1390.7050687363856)),
8: ((113.86384530944383, 1577.0293086386203),
(576.5344359685643, 627.777022204828)),
9: ((149.0259581360621, 1616.7687282225652),
(1386.5118806905086, 1439.843598490326)),
10: ((111.91564190274217, 1574.8351095461135),
(527.229828693075, 578.402894851317)),
11: ((151.3053466801954, 1619.3720722471498),
(1435.423685040875, 1488.917203728964)),
12: ((109.8957204607345, 1572.570246400894),
(477.9699083444277, 529.0782087498488)),
13: ((153.5023503173659, 1621.9005029476564),
(1484.38405923062, 1538.0443479389924)),
14: ((107.98320121613297, 1570.411787034636),
(428.6704834494425, 479.7217241891257)),
15: ((155.77991404913857, 1624.5184927460925),
(1533.169633314481, 1586.9984359105376)),
16: ((106.10212081215678, 1568.286103827344),
(379.3860245240618, 430.3780648366697)),
17: ((158.23149941845386, 1627.305849064835),
(1582.0496119714928, 1636.0513450787032)),
18: ((104.09366374413436, 1566.030231370944),
(330.0822744105267, 381.01974582564395)),
19: ((160.4511021152353, 1629.888830991371),
(1630.7797743277185, 1684.9592727079018)),
20: ((102.25220592881234, 1563.9475099032868),
(280.7233309522168, 331.6093009077988)),
21: ((162.72784286205734, 1632.5257403739463),
(1679.6815760587567, 1734.03692957156)),
22: ((100.40115742738622, 1561.8476640376036),
(231.35443588323855, 282.19575854747006)),
23: ((165.05939163941662, 1635.2270773628682),
(1728.511467615387, 1783.0485841263735)),
24: ((98.45723949658425, 1559.6499479349648),
(182.0417295679079, 232.83530870639865)),
25: ((167.44628840053574, 1637.9923229870349),
(1777.2512197664128, 1831.971115503598)),
26: ((96.56508092457855, 1557.5079027818058),
(132.5285162704088, 183.27350269292484)),
27: ((169.8529496136358, 1640.778485168005),
(1826.028691168028, 1880.9336718824313)),
28: ((94.71390837793813, 1555.4048050512263),
(82.94691422559131, 133.63901517357235)),
29: ((172.3681094850081, 1643.685604697228),
(1874.8184744639657, 1929.9072657798927))}

hdul = create_nirspec_ifu_file("F290LP", "G140M")
im = datamodels.IFUImageModel(hdul)
im.meta.filename = "test_ifu.fits"
refs = create_reference_files(im)

pipe = nirspec.create_pipeline(im, refs, slit_y_range=[-.5, .5])
w = wcs.WCS(pipe)
im.meta.wcs = w

_, wrange = nirspec.spectral_order_wrange_from_model(im)
pipe = im.meta.wcs.pipeline

g2s = pipe[2].transform
transforms = [pipe[0].transform]
transforms.append(pipe[1].transform[1:])
transforms.append(astmodels.Identity(1))
transforms.append(astmodels.Identity(1))
transforms.extend([step.transform for step in pipe[4:-1]])

for sl in range(30):
transforms[2] = g2s.get_model(sl)
m = functools.reduce(lambda x, y: x | y, [tr.inverse for tr in transforms[:3][::-1]])
bbox_sl = nirspec.compute_bounding_box(m, wrange)
assert_allclose(bbox[sl], bbox_sl)
90 changes: 67 additions & 23 deletions jwst/assign_wcs/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -887,14 +887,17 @@ def compute_footprint_spectral(model):
[np.nanmax(ra), np.nanmin(dec)],
[np.nanmax(ra), np.nanmax(dec)],
[np.nanmin(ra), np.nanmax(dec)]])
return footprint
lam_min = np.nanmin(lam)
lam_max = np.nanmax(lam)
return footprint, (lam_min, lam_max)


def update_s_region_spectral(model):
""" Update the S_REGION keyword.
"""
footprint = compute_footprint_spectral(model)
footprint, spectral_region = compute_footprint_spectral(model)
update_s_region_keyword(model, footprint)
model.meta.wcsinfo.spectral_region = spectral_region


def compute_footprint_nrs_slit(slit):
Expand All @@ -913,12 +916,16 @@ def compute_footprint_nrs_slit(slit):
ra, dec, lam = slit2world(virtual_corners_x,
virtual_corners_y,
input_lam)
return np.array([ra, dec]).T
footprint = np.array([ra, dec]).T
lam_min = np.nanmin(lam)
lam_max = np.nanmax(lam)
return footprint, (lam_min, lam_max)


def update_s_region_nrs_slit(slit):
footprint = compute_footprint_nrs_slit(slit)
footprint, spectral_region = compute_footprint_nrs_slit(slit)
update_s_region_keyword(slit, footprint)
slit.meta.wcsinfo.spectral_region = spectral_region


def update_s_region_keyword(model, footprint):
Expand All @@ -938,36 +945,71 @@ def update_s_region_keyword(model, footprint):
log.info("Update S_REGION to {}".format(model.meta.wcsinfo.s_region))


def _nanminmax(wcsobj):
x, y = grid_from_bounding_box(wcsobj.bounding_box)
ra, dec, lam = wcsobj(x, y)
return np.nanmin(ra), np.nanmax(ra), np.nanmin(dec), np.nanmax(dec)
def compute_footprint_nrs_ifu(dmodel, mod):
"""
Determine NIRSPEC IFU footprint using the instrument model.

For efficiency this function uses the transforms directly,
instead of the WCS object. The common transforms in the WCS
model chain are referenced and reused; only the slice specific
transforms are computed.

def compute_footprint_nrs_ifu(output_model, mod):
"""
determine NIRSPEC ifu footprint observations using the instrument model.
If the transforms change this function should be revised.

Parameters
----------
output_model : `~jwst.datamodels.IFUImageModel`
The output of assign_wcs.
mod : module
The imported ``nirspec`` module.

Returns
-------
footprint : ndarray
The spatial footprint
spectral_region : tuple
The wavelength range for the observation.
"""
wcs_list = mod.nrs_ifu_wcs(output_model)
ra_total = []
dec_total = []
for wcsobj in wcs_list:
rmin, rmax, dmin, dmax = _nanminmax(wcsobj)
ra_total.append((rmin, rmax))
dec_total.append((dmin, dmax))
ra_max = np.asarray(ra_total)[:, 1].max()
ra_min = np.asarray(ra_total)[:, 0].min()
dec_max = np.asarray(dec_total)[:, 1].max()
dec_min = np.asarray(dec_total)[:, 0].min()
lam_total = []
_, wrange = mod.spectral_order_wrange_from_model(dmodel)
pipe = dmodel.meta.wcs.pipeline

# Get the GWA to slit_frame transform
g2s = pipe[2].transform

# Construct a list of the transforms between coordinate frames.
# Set a place holder ``Identity`` transform at index 2 and 3.
# Update them with slice specific transforms.
transforms = [pipe[0].transform]
transforms.append(pipe[1].transform[1:])
transforms.append(astmodels.Identity(1))
transforms.append(astmodels.Identity(1))
transforms.extend([step.transform for step in pipe[4:-1]])

for sl in range(30):
transforms[2] = g2s.get_model(sl)
# Create the full transform from ``slit_frame`` to ``detector``.
# It is used to compute the bounding box.
m = functools.reduce(lambda x, y: x | y, [tr.inverse for tr in transforms[:3][::-1]])
bbox = mod.compute_bounding_box(m, wrange)
# Add the remaining transforms - from ``sli_frame`` to ``world``
transforms[3] = pipe[3].transform.get_model(sl) & astmodels.Identity(1)
mforw = functools.reduce(lambda x, y: x | y, transforms)
x1, y1 = grid_from_bounding_box(bbox)
ra, dec, lam = mforw(x1, y1)
ra_total.extend(np.ravel(ra))
dec_total.extend(np.ravel(dec))
lam_total.extend(np.ravel(lam))
ra_max = np.nanmax(ra_total)
ra_min = np.nanmin(ra_total)
dec_max = np.nanmax(dec_total)
dec_min = np.nanmin(dec_total)
lam_max = np.nanmax(lam_total)
lam_min = np.nanmin(lam_total)
footprint = np.array([ra_min, dec_min, ra_max, dec_min, ra_max, dec_max, ra_min, dec_max])
return footprint
return footprint, (lam_min, lam_max)


def update_s_region_nrs_ifu(output_model, mod):
Expand All @@ -981,8 +1023,9 @@ def update_s_region_nrs_ifu(output_model, mod):
mod : module
The imported ``nirspec`` module.
"""
footprint = compute_footprint_nrs_ifu(output_model, mod)
footprint, spectral_region = compute_footprint_nrs_ifu(output_model, mod)
update_s_region_keyword(output_model, footprint)
output_model.meta.wcsinfo.spectral_region = spectral_region


def update_s_region_mrs(output_model):
Expand All @@ -994,8 +1037,9 @@ def update_s_region_mrs(output_model):
output_model : `~jwst.datamodels.IFUImageModel`
The output of assign_wcs.
"""
footprint = compute_footprint_spectral(output_model)
footprint, spectral_region = compute_footprint_spectral(output_model)
update_s_region_keyword(output_model, footprint)
output_model.meta.wcsinfo.spectral_region = spectral_region


def velocity_correction(velosys):
Expand Down