From fc20b8c6509033ea86ab05bb824a649ebec31c9e Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Wed, 3 Jul 2024 14:22:26 -0400 Subject: [PATCH] tweakreg into stcal --- jwst/assign_wcs/__init__.py | 3 +- jwst/assign_wcs/assign_wcs_step.py | 4 +- jwst/assign_wcs/util.py | 164 +------------ jwst/resample/resample.py | 2 +- jwst/resample/resample_utils.py | 31 ++- jwst/tweakreg/astrometric_utils.py | 207 ---------------- jwst/tweakreg/tests/test_amutils.py | 75 ------ jwst/tweakreg/tests/test_tweakreg.py | 7 +- jwst/tweakreg/tests/test_utils.py | 3 +- jwst/tweakreg/tweakreg_step.py | 344 ++++++--------------------- jwst/tweakreg/utils.py | 43 +--- pyproject.toml | 2 +- 12 files changed, 104 insertions(+), 781 deletions(-) delete mode 100644 jwst/tweakreg/astrometric_utils.py delete mode 100644 jwst/tweakreg/tests/test_amutils.py diff --git a/jwst/assign_wcs/__init__.py b/jwst/assign_wcs/__init__.py index bd27aca1fa4..728738dd495 100644 --- a/jwst/assign_wcs/__init__.py +++ b/jwst/assign_wcs/__init__.py @@ -1,7 +1,6 @@ from .assign_wcs_step import AssignWcsStep from .nirspec import (nrs_wcs_set_input, nrs_ifu_wcs, get_spectral_order_wrange) from .niriss import niriss_soss_set_input -from .util import update_fits_wcsinfo __all__ = ['AssignWcsStep', "nrs_wcs_set_input", "nrs_ifu_wcs", "get_spectral_order_wrange", - "niriss_soss_set_input", "update_fits_wcsinfo"] + "niriss_soss_set_input"] diff --git a/jwst/assign_wcs/assign_wcs_step.py b/jwst/assign_wcs/assign_wcs_step.py index 324bef826d4..025dff952bc 100755 --- a/jwst/assign_wcs/assign_wcs_step.py +++ b/jwst/assign_wcs/assign_wcs_step.py @@ -5,8 +5,8 @@ from ..lib.exposure_types import IMAGING_TYPES import logging from .assign_wcs import load_wcs -from .util import MSAFileError, update_fits_wcsinfo -from .util import wfss_imaging_wcs, wcs_bbox_from_shape +from .util import MSAFileError +from .util import wfss_imaging_wcs, wcs_bbox_from_shape, update_fits_wcsinfo from .nircam import imaging as nircam_imaging from .niriss import imaging as niriss_imaging diff --git a/jwst/assign_wcs/util.py b/jwst/assign_wcs/util.py index 2767182fe09..91e1c390fa5 100644 --- a/jwst/assign_wcs/util.py +++ b/jwst/assign_wcs/util.py @@ -34,9 +34,9 @@ _MAX_SIP_DEGREE = 6 -__all__ = ["reproject", "wcs_from_footprints", "velocity_correction", +__all__ = ["reproject", "velocity_correction", "MSAFileError", "NoDataOnDetectorError", "compute_scale", - "calc_rotation_matrix", "wrap_ra", "update_fits_wcsinfo"] + "calc_rotation_matrix", "wrap_ra", "update_fits_wcsinfo",] class MSAFileError(Exception): @@ -197,164 +197,6 @@ def calc_rotation_matrix(roll_ref: float, v3i_yang: float, vparity: int = 1) -> return [pc1_1, pc1_2, pc2_1, pc2_2] -def wcs_from_footprints(dmodels, refmodel=None, transform=None, bounding_box=None, - pscale_ratio=None, pscale=None, rotation=None, - shape=None, crpix=None, crval=None, wcslist=None): - """ - Create a WCS from a list of input data models. - - A fiducial point in the output coordinate frame is created from the - footprints of all WCS objects. For a spatial frame this is the center - of the union of the footprints. For a spectral frame the fiducial is in - the beginning of the footprint range. - If ``refmodel`` is None, the first WCS object in the list is considered - a reference. The output coordinate frame and projection (for celestial frames) - is taken from ``refmodel``. - If ``transform`` is not supplied, a compound transform is created using - CDELTs and PC. - If ``bounding_box`` is not supplied, the bounding_box of the new WCS is computed - from bounding_box of all input WCSs. - - Parameters - ---------- - dmodels : list of `~jwst.datamodels.JwstDataModel` - A list of data models. - refmodel : `~jwst.datamodels.JwstDataModel`, optional - This model's WCS is used as a reference. - WCS. The output coordinate frame, the projection and a - scaling and rotation transform is created from it. If not supplied - the first model in the list is used as ``refmodel``. - transform : `~astropy.modeling.core.Model`, optional - A transform, passed to :meth:`~gwcs.wcstools.wcs_from_fiducial` - If not supplied Scaling | Rotation is computed from ``refmodel``. - bounding_box : tuple, optional - Bounding_box of the new WCS. - If not supplied it is computed from the bounding_box of all inputs. - pscale_ratio : float, None, optional - Ratio of input to output pixel scale. Ignored when either - ``transform`` or ``pscale`` are provided. - pscale : float, None, optional - Absolute pixel scale in degrees. When provided, overrides - ``pscale_ratio``. Ignored when ``transform`` is provided. - rotation : float, None, optional - Position angle of output image’s Y-axis relative to North. - A value of 0.0 would orient the final output image to be North up. - The default of `None` specifies that the images will not be rotated, - but will instead be resampled in the default orientation for the camera - with the x and y axes of the resampled image corresponding - approximately to the detector axes. Ignored when ``transform`` is - provided. - shape : tuple of int, None, optional - Shape of the image (data array) using ``numpy.ndarray`` convention - (``ny`` first and ``nx`` second). This value will be assigned to - ``pixel_shape`` and ``array_shape`` properties of the returned - WCS object. - crpix : tuple of float, None, optional - Position of the reference pixel in the image array. If ``crpix`` is not - specified, it will be set to the center of the bounding box of the - returned WCS object. - crval : tuple of float, None, optional - Right ascension and declination of the reference pixel. Automatically - computed if not provided. - - """ - bb = bounding_box - if wcslist is None: - wcslist = [im.meta.wcs for im in dmodels] - - if not isiterable(wcslist): - raise ValueError("Expected 'wcslist' to be an iterable of WCS objects.") - - if not all([isinstance(w, WCS) for w in wcslist]): - raise TypeError("All items in wcslist are to be instances of gwcs.WCS.") - - if refmodel is None: - refmodel = dmodels[0] - else: - if not isinstance(refmodel, JwstDataModel): - raise TypeError("Expected refmodel to be an instance of DataModel.") - - fiducial = compute_fiducial(wcslist, bb) - if crval is not None: - # overwrite spatial axes with user-provided CRVAL: - i = 0 - for k, axt in enumerate(wcslist[0].output_frame.axes_type): - if axt == 'SPATIAL': - fiducial[k] = crval[i] - i += 1 - - ref_fiducial = np.array([refmodel.meta.wcsinfo.ra_ref, refmodel.meta.wcsinfo.dec_ref]) - - prj = astmodels.Pix2Sky_TAN() - - if transform is None: - transform = [] - wcsinfo = pointing.wcsinfo_from_model(refmodel) - sky_axes, spec, other = gwutils.get_axes(wcsinfo) - - # Need to put the rotation matrix (List[float, float, float, float]) - # returned from calc_rotation_matrix into the correct shape for - # constructing the transformation - v3yangle = np.deg2rad(refmodel.meta.wcsinfo.v3yangle) - vparity = refmodel.meta.wcsinfo.vparity - if rotation is None: - roll_ref = np.deg2rad(refmodel.meta.wcsinfo.roll_ref) - else: - roll_ref = np.deg2rad(rotation) + (vparity * v3yangle) - - pc = np.reshape( - calc_rotation_matrix(roll_ref, v3yangle, vparity=vparity), - (2, 2) - ) - - rotation = astmodels.AffineTransformation2D(pc, name='pc_rotation_matrix') - transform.append(rotation) - - if sky_axes: - if not pscale: - pscale = compute_scale(refmodel.meta.wcs, ref_fiducial, - pscale_ratio=pscale_ratio) - transform.append(astmodels.Scale(pscale, name='cdelt1') & astmodels.Scale(pscale, name='cdelt2')) - - if transform: - transform = functools.reduce(lambda x, y: x | y, transform) - - out_frame = refmodel.meta.wcs.output_frame - input_frame = refmodel.meta.wcs.input_frame - wnew = wcs_from_fiducial(fiducial, coordinate_frame=out_frame, projection=prj, - transform=transform, input_frame=input_frame) - - footprints = [w.footprint().T for w in wcslist] - domain_bounds = np.hstack([wnew.backward_transform(*f) for f in footprints]) - axis_min_values = np.min(domain_bounds, axis=1) - domain_bounds = (domain_bounds.T - axis_min_values).T - - output_bounding_box = [] - for axis in out_frame.axes_order: - axis_min, axis_max = domain_bounds[axis].min(), domain_bounds[axis].max() - output_bounding_box.append((axis_min, axis_max)) - - output_bounding_box = tuple(output_bounding_box) - if crpix is None: - offset1, offset2 = wnew.backward_transform(*fiducial) - offset1 -= axis_min_values[0] - offset2 -= axis_min_values[1] - else: - offset1, offset2 = crpix - offsets = astmodels.Shift(-offset1, name='crpix1') & astmodels.Shift(-offset2, name='crpix2') - - wnew.insert_transform('detector', offsets, after=True) - wnew.bounding_box = output_bounding_box - - if shape is None: - shape = [int(axs[1] - axs[0] + 0.5) for axs in output_bounding_box[::-1]] - - wnew.pixel_shape = shape[::-1] - wnew.array_shape = shape - - return wnew - - def compute_fiducial(wcslist, bounding_box=None): """ For a celestial footprint this is the center. @@ -1283,7 +1125,6 @@ def update_fits_wcsinfo(datamodel, max_pix_error=0.01, degree=None, Parameters ---------- - datamodel : `ImageModel` The input data model for imaging or WFSS mode whose ``meta.wcsinfo`` field should be updated from GWCS. By default, ``datamodel.meta.wcs`` @@ -1386,7 +1227,6 @@ def update_fits_wcsinfo(datamodel, max_pix_error=0.01, degree=None, Notes ----- - Use of this requires a judicious choice of required accuracies. Attempts to use higher degrees (~7 or higher) will typically fail due to floating point problems that arise with high powers. diff --git a/jwst/resample/resample.py b/jwst/resample/resample.py index 1a9ff734c35..641e6dc88fc 100644 --- a/jwst/resample/resample.py +++ b/jwst/resample/resample.py @@ -14,7 +14,7 @@ from jwst.datamodels import ModelContainer from . import gwcs_drizzle -from . import resample_utils +from jwst.resample import resample_utils from ..model_blender import blendmeta log = logging.getLogger(__name__) diff --git a/jwst/resample/resample_utils.py b/jwst/resample/resample_utils.py index ed4447d8b55..0067ae2af3a 100644 --- a/jwst/resample/resample_utils.py +++ b/jwst/resample/resample_utils.py @@ -9,7 +9,8 @@ from stdatamodels.dqflags import interpret_bit_flags from stdatamodels.jwst.datamodels.dqflags import pixel -from jwst.assign_wcs.util import wcs_from_footprints, wcs_bbox_from_shape +from jwst.assign_wcs.util import wcs_bbox_from_shape +from stcal.alignment import util log = logging.getLogger(__name__) @@ -22,21 +23,23 @@ def make_output_wcs(input_models, ref_wcs=None, pscale_ratio=None, pscale=None, rotation=None, shape=None, crpix=None, crval=None): - """ Generate output WCS here based on footprints of all input WCS objects + """Generate output WCS here based on footprints of all input WCS objects. + Parameters ---------- - input_models : list of `~jwst.datamodel.JwstDataModel` + input_models : list of `DataModel objects` Each datamodel must have a ~gwcs.WCS object. pscale_ratio : float, optional - Ratio of input to output pixel scale. Ignored when ``pscale`` is provided. + Ratio of input to output pixel scale. Ignored when ``pscale`` + is provided. pscale : float, None, optional Absolute pixel scale in degrees. When provided, overrides ``pscale_ratio``. rotation : float, None, optional - Position angle of output image’s Y-axis relative to North. + Position angle of output image Y-axis relative to North. A value of 0.0 would orient the final output image to be North up. The default of `None` specifies that the images will not be rotated, but will instead be resampled in the default orientation for the camera @@ -50,7 +53,7 @@ def make_output_wcs(input_models, ref_wcs=None, WCS object. crpix : tuple of float, None, optional - Position of the reference pixel in the image array. If ``crpix`` is not + Position of the reference pixel in the image array. If ``crpix`` is not specified, it will be set to the center of the bounding box of the returned WCS object. @@ -62,7 +65,6 @@ def make_output_wcs(input_models, ref_wcs=None, ------- output_wcs : object WCS object, with defined domain, covering entire set of input frames - """ if ref_wcs is None: wcslist = [i.meta.wcs for i in input_models] @@ -72,10 +74,11 @@ def make_output_wcs(input_models, ref_wcs=None, naxes = wcslist[0].output_frame.naxes if naxes != 2: - raise RuntimeError("Output WCS needs 2 spatial axes. " - f"{wcslist[0]} has {naxes}.") + msg = f"Output WCS needs 2 spatial axes \ + but the supplied WCS has {naxes} axes." + raise RuntimeError(msg) - output_wcs = wcs_from_footprints( + output_wcs = util.wcs_from_footprints( input_models, pscale_ratio=pscale_ratio, pscale=pscale, @@ -88,15 +91,17 @@ def make_output_wcs(input_models, ref_wcs=None, else: naxes = ref_wcs.output_frame.naxes if naxes != 2: - raise RuntimeError("Output WCS needs 2 spatial axes but the " - f"supplied WCS has {naxes} axes.") + msg = f"Output WCS needs 2 spatial axes \ + but the supplied WCS has {naxes} axes." + raise RuntimeError(msg) output_wcs = deepcopy(ref_wcs) if shape is not None: output_wcs.array_shape = shape # Check that the output data shape has no zero length dimensions if not np.prod(output_wcs.array_shape): - raise ValueError(f"Invalid output frame shape: {tuple(output_wcs.array_shape)}") + msg = f"Invalid output frame shape: {tuple(output_wcs.array_shape)}" + raise ValueError(msg) return output_wcs diff --git a/jwst/tweakreg/astrometric_utils.py b/jwst/tweakreg/astrometric_utils.py deleted file mode 100644 index a3034cb6c81..00000000000 --- a/jwst/tweakreg/astrometric_utils.py +++ /dev/null @@ -1,207 +0,0 @@ -import os -import requests - -from astropy import table -from astropy.time import Time -from astropy.table import Table -from astropy.coordinates import SkyCoord -from astropy import units as u - -from ..resample import resample_utils -from ..assign_wcs import util as wcsutil - -ASTROMETRIC_CAT_ENVVAR = "ASTROMETRIC_CATALOG_URL" -DEF_CAT_URL = 'http://gsss.stsci.edu/webservices' - -if ASTROMETRIC_CAT_ENVVAR in os.environ: - SERVICELOCATION = os.environ[ASTROMETRIC_CAT_ENVVAR] -else: - SERVICELOCATION = DEF_CAT_URL - -TIMEOUT = 30.0 # in seconds - -""" - -Primary function for creating an astrometric reference catalog. - -""" - - -__all__ = ["TIMEOUT", "compute_radius", "create_astrometric_catalog", "get_catalog"] - - -def create_astrometric_catalog(input_models, catalog="GAIADR3", output="ref_cat.ecsv", - gaia_only=False, table_format="ascii.ecsv", - existing_wcs=None, num_sources=None, epoch=None): - """Create an astrometric catalog that covers the inputs' field-of-view. - - Parameters - ---------- - input_models : list of `~jwst.datamodel.JwstDataModel` - Each datamodel must have a ~gwcs.WCS object. - - catalog : str, optional - Name of catalog to extract astrometric positions for sources in the - input images' field-of-view. Default: GAIADR3. Options available are - documented on the catalog web page. - - output : str, optional - Filename to give to the astrometric catalog read in from the master - catalog web service. If None, no file will be written out. - - gaia_only : bool, optional - Specify whether or not to only use sources from GAIA in output catalog - - existing_wcs : model - existing WCS object specified by the user as generated by - `resample.resample_utils.make_output_wcs` - - num_sources : int - Maximum number of brightest/faintest sources to return in catalog. - If `num_sources` is negative, return that number of the faintest - sources. By default, all sources are returned. - - epoch : float, optional - Reference epoch used to update the coordinates for proper motion - (in decimal year). If `None` then the epoch is obtained from - the metadata. - - Notes - ----- - This function will point to astrometric catalog web service defined - through the use of the ASTROMETRIC_CATALOG_URL environment variable. - - Returns - ------- - ref_table : `~astropy.table.Table` - Astropy Table object of the catalog - - """ - - # start by creating a composite field-of-view for all inputs - # This default output WCS will have the same plate-scale and orientation - # as the first member in the list. - # Fortunately, for alignment, this doesn't matter since no resampling of - # data will be performed. - if existing_wcs is not None: - outwcs = existing_wcs - else: - outwcs = resample_utils.make_output_wcs(input_models) - radius, fiducial = compute_radius(outwcs) - - # perform query for this field-of-view - epoch = ( - epoch - if epoch is not None - else Time(input_models[0].meta.observation.date).decimalyear - ) - ref_dict = get_catalog(fiducial[0], fiducial[1], epoch=epoch, sr=radius, catalog=catalog) - if len(ref_dict) == 0: - return ref_dict - - colnames = ('ra', 'dec', 'mag', 'objID', 'epoch') - ref_table = ref_dict[colnames] - - # Add catalog name as meta data - ref_table.meta['catalog'] = catalog - ref_table.meta['gaia_only'] = gaia_only - - # rename coordinate columns to be consistent with tweakwcs - ref_table.rename_column('ra', 'RA') - ref_table.rename_column('dec', 'DEC') - - # Append GAIA ID as a new column to the table... - gaia_sources = [] - for source in ref_dict: - if 'GAIAsourceID' in source: - g = source['GAIAsourceID'] - if gaia_only and g.strip() == '': - continue - else: - g = "-1" # indicator for no source ID extracted - gaia_sources.append(g) - - gaia_col = table.Column(data=gaia_sources, name='GaiaID', dtype='U25') - ref_table.add_column(gaia_col) - - # sort table by magnitude, fainter to brightest - ref_table.sort('mag', reverse=True) - - # If specified by the use through the 'num_sources' parameter, - # trim the returned catalog down to just the brightest 'num_sources' sources - # Should 'num_sources' be a negative value, it will return the faintest - # 'num_sources' sources. - if num_sources is not None: - indx = -1 * num_sources - ref_table = ref_table[:indx] if num_sources < 0 else ref_table[indx:] - - # Write out table to a file, if specified - if output is not None: - ref_table.write(output, format=table_format, overwrite=True) - - return ref_table - - -""" - -Utility functions for creating an astrometric reference catalog. - -""" - - -def compute_radius(wcs): - """Compute the radius from the center to the furthest edge of the WCS.""" - - fiducial = wcsutil.compute_fiducial([wcs], wcs.bounding_box) - img_center = SkyCoord(ra=fiducial[0] * u.degree, dec=fiducial[1] * u.degree) - wcs_foot = wcs.footprint() - img_corners = SkyCoord(ra=wcs_foot[:, 0] * u.degree, - dec=wcs_foot[:, 1] * u.degree) - radius = img_center.separation(img_corners).max().value - - return radius, fiducial - - -def get_catalog(ra, dec, epoch=2016.0, sr=0.1, catalog='GAIADR3'): - """ Extract catalog from VO web service. - - Parameters - ---------- - ra : float - Right Ascension (RA) of center of field-of-view (in decimal degrees) - - dec : float - Declination (Dec) of center of field-of-view (in decimal degrees) - - epoch : float, optional - Reference epoch used to update the coordinates for proper motion - (in decimal year). Default: 2016.0 - - sr : float, optional - Search radius (in decimal degrees) from field-of-view center to use - for sources from catalog. Default: 0.1 degrees - - catalog : str, optional - Name of catalog to query, as defined by web-service. Default: 'GAIADR3' - - Returns - ------- - csv : CSV object - CSV object of returned sources with all columns as provided by catalog - - """ - service_type = "vo/CatalogSearch.aspx" - spec_str = "RA={}&DEC={}&EPOCH={}&SR={}&FORMAT={}&CAT={}&MINDET=5" - headers = {"Content-Type": "text/csv"} - fmt = "CSV" - - spec = spec_str.format(ra, dec, epoch, sr, fmt, catalog) - service_url = f"{SERVICELOCATION}/{service_type}?{spec}" - rawcat = requests.get(service_url, headers=headers, timeout=TIMEOUT) - r_contents = rawcat.content.decode() # convert from bytes to a String - rstr = r_contents.split('\r\n') - # remove initial line describing the number of sources returned - # CRITICAL to proper interpretation of CSV data - del rstr[0] - - return Table.read(rstr, format='csv') diff --git a/jwst/tweakreg/tests/test_amutils.py b/jwst/tweakreg/tests/test_amutils.py deleted file mode 100644 index 3f2e40fcaf1..00000000000 --- a/jwst/tweakreg/tests/test_amutils.py +++ /dev/null @@ -1,75 +0,0 @@ -"""Test astrometric utility functions for alignment""" -import os - -import asdf -import numpy as np -import pytest - -from jwst.tweakreg import astrometric_utils as amutils - - -# Define input GWCS specification to be used for these tests -WCS_NAME = 'mosaic_long_i2d_gwcs.asdf' # Derived using B7.5 Level 3 product -EXPECTED_NUM_SOURCES = 2469 -EXPECTED_RADIUS = 0.02564497890604383 -TEST_CATALOG = 'GAIADR3' - - -@pytest.fixture(scope="module") -def wcsobj(): - path = os.path.join(os.path.dirname(__file__), WCS_NAME) - with asdf.open(path) as asdf_file: - wcs = asdf_file['wcs'] - - return wcs - - -def test_radius(wcsobj): - # compute radius - radius, fiducial = amutils.compute_radius(wcsobj) - - # check results - np.testing.assert_allclose(radius, EXPECTED_RADIUS, rtol=1e-6) - - -def test_get_catalog(wcsobj): - # Get radius and fiducial - radius, fiducial = amutils.compute_radius(wcsobj) - - # Get the catalog - cat = amutils.get_catalog(fiducial[0], fiducial[1], sr=radius, - catalog=TEST_CATALOG) - - assert len(cat) == EXPECTED_NUM_SOURCES - - -def test_create_catalog(wcsobj): - # Create catalog - gcat = amutils.create_astrometric_catalog( - None, - existing_wcs=wcsobj, - catalog=TEST_CATALOG, - output=None, - epoch='2016.0', - ) - # check that we got expected number of sources - assert len(gcat) == EXPECTED_NUM_SOURCES - - -def test_create_catalog_graceful_failure(wcsobj): - ''' - Ensure catalog retuns zero sources instead of failing outright - when the bounding box is too small to find any sources - ''' - wcsobj.bounding_box = ((0, 0.5), (0, 0.5)) - - # Create catalog - gcat = amutils.create_astrometric_catalog( - None, - existing_wcs=wcsobj, - catalog=TEST_CATALOG, - output=None, - epoch='2016.0', - ) - # check that we got expected number of sources - assert len(gcat) == 0 diff --git a/jwst/tweakreg/tests/test_tweakreg.py b/jwst/tweakreg/tests/test_tweakreg.py index da5e5bd8945..a6a5bf83578 100644 --- a/jwst/tweakreg/tests/test_tweakreg.py +++ b/jwst/tweakreg/tests/test_tweakreg.py @@ -14,12 +14,14 @@ from jwst.datamodels import ModelContainer from jwst.tweakreg import tweakreg_step from jwst.tweakreg import tweakreg_catalog -from jwst.tweakreg.utils import _wcsinfo_from_wcs_transform +from stcal.tweakreg.utils import _wcsinfo_from_wcs_transform +from stcal.tweakreg import tweakreg as twk BKG_LEVEL = 0.001 N_EXAMPLE_SOURCES = 21 N_CUSTOM_SOURCES = 15 +REFCAT = "GAIADR3" @pytest.fixture @@ -213,6 +215,7 @@ def test_src_confusion_pars(example_input, alignment_type): pars = { f"{alignment_type}separation": 1.0, f"{alignment_type}tolerance": 1.0, + "abs_refcat": REFCAT, } step = tweakreg_step.TweakRegStep(**pars) result = step(example_input) @@ -338,7 +341,7 @@ def patched_construct_wcs_corrector(model, catalog, _seen=[]): raise ValueError("done testing") return None - monkeypatch.setattr(tweakreg_step, "_construct_wcs_corrector", patched_construct_wcs_corrector) + monkeypatch.setattr(twk, "construct_wcs_corrector", patched_construct_wcs_corrector) with pytest.raises(ValueError, match="done testing"): step(str(asn_path)) diff --git a/jwst/tweakreg/tests/test_utils.py b/jwst/tweakreg/tests/test_utils.py index e05fca7e98b..05fb3969fce 100644 --- a/jwst/tweakreg/tests/test_utils.py +++ b/jwst/tweakreg/tests/test_utils.py @@ -14,9 +14,8 @@ from jwst.tweakreg.utils import ( adjust_wcs, transfer_wcs_correction, - _wcsinfo_from_wcs_transform ) - +from stcal.tweakreg.utils import _wcsinfo_from_wcs_transform data_path = path.split(path.abspath(data.__file__))[0] diff --git a/jwst/tweakreg/tweakreg_step.py b/jwst/tweakreg/tweakreg_step.py index d928cb166ae..1efee6c5201 100644 --- a/jwst/tweakreg/tweakreg_step.py +++ b/jwst/tweakreg/tweakreg_step.py @@ -10,23 +10,19 @@ from astropy import units as u from astropy.coordinates import SkyCoord from astropy.table import Table -from astropy.time import Time -from tweakwcs.imalign import align_wcs from tweakwcs.correctors import JWSTWCSCorrector -from tweakwcs.matchutils import XYXYMatch + +import stcal.tweakreg.tweakreg as twk +from stcal.alignment import update_s_region_imaging from jwst.datamodels import ModelContainer +from jwst.assign_wcs.util import update_fits_wcsinfo # LOCAL from ..stpipe import Step -from ..assign_wcs.util import update_fits_wcsinfo, update_s_region_imaging, wcs_from_footprints -from .astrometric_utils import create_astrometric_catalog from .tweakreg_catalog import make_tweakreg_catalog -_SQRT2 = math.sqrt(2.0) - - def _oxford_or_str_join(str_list): nelem = len(str_list) if not nelem: @@ -136,26 +132,6 @@ class TweakRegStep(Step): def process(self, input): images = ModelContainer(input) - if self.separation <= _SQRT2 * self.tolerance: - self.log.error( - "Parameter 'separation' must be larger than 'tolerance' by at " - "least a factor of sqrt(2) to avoid source confusion." - ) - for model in images: - model.meta.cal_step.tweakreg = "SKIPPED" - self.log.warning("Skipping 'TweakRegStep' step.") - return input - - if self.abs_separation <= _SQRT2 * self.abs_tolerance: - self.log.error( - "Parameter 'abs_separation' must be larger than 'abs_tolerance' " - "by at least a factor of sqrt(2) to avoid source confusion." - ) - for model in images: - model.meta.cal_step.tweakreg = "SKIPPED" - self.log.warning("Skipping 'TweakRegStep' step.") - return input - if len(images) == 0: raise ValueError("Input must contain at least one image model.") @@ -214,7 +190,7 @@ def process(self, input): # pre-allocate collectors (same length and order as images) correctors = [None] * len(images) - # Build the catalog and corrector for each input images + # Build the catalog for each input image for (model_index, image_model) in enumerate(images): # now that the model is open, check it's metadata for a custom catalog # only if it's not listed in the catdict @@ -266,237 +242,78 @@ def process(self, input): image_model.meta.tweakreg_catalog = self._write_catalog(catalog, filename) # construct the corrector since the model is open (and already has a group_id) - correctors[model_index] = _construct_wcs_corrector(image_model, catalog) + correctors[model_index] = twk.construct_wcs_corrector(image_model, catalog) self.log.info('') self.log.info("Number of image groups to be aligned: {:d}." .format(n_groups)) - # keep track of if 'local' alignment failed, even if this - # fails, absolute alignment might be run (if so configured) - local_align_failed = False - - # if we have >1 group of images, align them to each other - if n_groups > 1: - - # align images: - xyxymatch = XYXYMatch( - searchrad=self.searchrad, - separation=self.separation, - use2dhist=self.use2dhist, - tolerance=self.tolerance, - xoffset=self.xoffset, - yoffset=self.yoffset - ) - - try: - align_wcs( - correctors, - refcat=None, - enforce_user_order=self.enforce_user_order, - expand_refcat=self.expand_refcat, - minobj=self.minobj, - match=xyxymatch, - fitgeom=self.fitgeometry, - nclip=self.nclip, - sigma=(self.sigma, 'rmse') - ) - - except ValueError as e: - msg = e.args[0] - if (msg == "Too few input images (or groups of images) with " - "non-empty catalogs."): - # we need at least two exposures to perform image alignment - self.log.warning(msg) - self.log.warning("At least two exposures are required for " - "image alignment.") - self.log.warning("Nothing to do. Skipping 'TweakRegStep'...") - for model in images: - model.meta.cal_step.tweakreg = "SKIPPED" - if not align_to_abs_refcat: - self.skip = True - return images - local_align_failed = True - else: - raise e - - except RuntimeError as e: - msg = e.args[0] - if msg.startswith("Number of output coordinates exceeded allocation"): - # we need at least two exposures to perform image alignment - self.log.error(msg) - self.log.error("Multiple sources within specified tolerance " - "matched to a single reference source. Try to " - "adjust 'tolerance' and/or 'separation' parameters.") - self.log.warning("Skipping 'TweakRegStep'...") - self.skip = True - for model in images: - model.meta.cal_step.tweakreg = "SKIPPED" - return images - else: - raise e - - if not local_align_failed and not self._is_wcs_correction_small(correctors): - if align_to_abs_refcat: - self.log.warning("Skipping relative alignment (stage 1)...") - else: - self.log.warning("Skipping 'TweakRegStep'...") - self.skip = True - for model in images: - model.meta.cal_step.tweakreg = "SKIPPED" - return images - - if align_to_abs_refcat: - # now, align things to the reference catalog - # this can occur after alignment between groups (only if >1 group) - - # Get catalog of GAIA sources for the field - # - # NOTE: If desired, the pipeline can write out the reference - # catalog as a separate product with a name based on - # whatever convention is determined by the JWST Cal Working - # Group. - - if self.save_abs_catalog: - if self.output_dir is None: - output_name = 'fit_{}_ref.ecsv'.format(self.abs_refcat.lower()) - else: - output_name = path.join(self.output_dir, 'fit_{}_ref.ecsv'.format(self.abs_refcat.lower())) + # wrapper to stcal tweakreg routines + # step skip conditions should throw TweakregError from stcal + try: + # relative alignment of images to each other (if more than one group) + if n_groups > 1: + correctors, local_align_failed = \ + twk.relative_align(correctors, + enforce_user_order=self.enforce_user_order, + expand_refcat=self.expand_refcat, + minobj=self.minobj, + fitgeometry=self.fitgeometry, + nclip=self.nclip, + sigma=self.sigma, + searchrad=self.searchrad, + use2dhist=self.use2dhist, + separation=self.separation, + tolerance=self.tolerance, + xoffset=self.xoffset, + yoffset=self.yoffset, + align_to_abs_refcat=align_to_abs_refcat) else: - output_name = None - - # initial shift to be used with absolute astrometry - self.abs_xoffset = 0 - self.abs_yoffset = 0 - - self.abs_refcat = self.abs_refcat.strip() - gaia_cat_name = self.abs_refcat.upper() - - if gaia_cat_name in SINGLE_GROUP_REFCAT: - ref_model = images[0] - - epoch = Time(ref_model.meta.observation.date).decimalyear - - # combine all aligned wcs to compute a new footprint to - # filter the absolute catalog sources - combined_wcs = wcs_from_footprints( - None, - refmodel=ref_model, - wcslist=[corrector.wcs for corrector in correctors], - ) + local_align_failed = True + + # absolute alignment to the reference catalog + # can (and does) occur after alignment between groups + if align_to_abs_refcat: + correctors = \ + twk.absolute_align(correctors, self.abs_refcat, images[0], + abs_minobj=self.abs_minobj, + abs_fitgeometry=self.abs_fitgeometry, + abs_nclip=self.abs_nclip, + abs_sigma=self.abs_sigma, + abs_searchrad=self.abs_searchrad, + abs_use2dhist=self.abs_use2dhist, + abs_separation=self.abs_separation, + abs_tolerance=self.abs_tolerance, + save_abs_catalog=self.save_abs_catalog, + abs_catalog_output_dir=self.output_dir, + local_align_failed=local_align_failed, + ) + + # one final pass through all the models to update them based + # on the results of this step + self._apply_tweakreg_solution(images, correctors, + align_to_abs_refcat=align_to_abs_refcat) + + except twk.TweakregError as e: + self.log.error(str(e)) + for model in images: + model.meta.cal_step.tweakreg = "SKIPPED" + return images - ref_cat = create_astrometric_catalog( - None, - gaia_cat_name, - existing_wcs=combined_wcs, - output=output_name, - epoch=epoch, - ) + return images - elif path.isfile(self.abs_refcat): - ref_cat = Table.read(self.abs_refcat) - else: - raise ValueError("'abs_refcat' must be a path to an " - "existing file name or one of the supported " - f"reference catalogs: {_SINGLE_GROUP_REFCAT_STR}.") - - # Check that there are enough GAIA sources for a reliable/valid fit - num_ref = len(ref_cat) - if num_ref < self.abs_minobj: - self.log.warning( - f"Not enough sources ({num_ref}) in the reference catalog " - "for the single-group alignment step to perform a fit. " - f"Skipping alignment to the {self.abs_refcat} reference " - "catalog!" - ) - else: - # align images: - # Update to separation needed to prevent confusion of sources - # from overlapping images where centering is not consistent or - # for the possibility that errors still exist in relative overlap. - xyxymatch_gaia = XYXYMatch( - searchrad=self.abs_searchrad, - separation=self.abs_separation, - use2dhist=self.abs_use2dhist, - tolerance=self.abs_tolerance, - xoffset=self.abs_xoffset, - yoffset=self.abs_yoffset - ) + def _apply_tweakreg_solution(self, + images: ModelContainer, + correctors: list[JWSTWCSCorrector], + align_to_abs_refcat: bool = False, + ) -> ModelContainer: - # Set group_id to same value so all get fit as one observation - # The assigned value, 987654, has been hard-coded to make it - # easy to recognize when alignment to GAIA was being performed - # as opposed to the group_id values used for relative alignment - # earlier in this step. - for corrector in correctors: - corrector.meta['group_id'] = 987654 - if ('fit_info' in corrector.meta and - 'REFERENCE' in corrector.meta['fit_info']['status']): - del corrector.meta['fit_info'] - - # Perform fit - try: - align_wcs( - correctors, - refcat=ref_cat, - enforce_user_order=True, - expand_refcat=False, - minobj=self.abs_minobj, - match=xyxymatch_gaia, - fitgeom=self.abs_fitgeometry, - nclip=self.abs_nclip, - sigma=(self.abs_sigma, 'rmse') - ) - except ValueError as e: - msg = e.args[0] - if (msg == "Too few input images (or groups of images) with " - "non-empty catalogs."): - # we need at least two exposures to perform image alignment - self.log.warning(msg) - self.log.warning( - "At least one exposure is required to align images " - "to an absolute reference catalog. Alignment to an " - "absolute reference catalog will not be performed." - ) - if local_align_failed or n_groups == 1: - self.log.warning("Nothing to do. Skipping 'TweakRegStep'...") - for model in images: - model.meta.cal_step.tweakreg = "SKIPPED" - self.skip = True - return images - else: - raise e - - except RuntimeError as e: - msg = e.args[0] - if msg.startswith("Number of output coordinates exceeded allocation"): - # we need at least two exposures to perform image alignment - self.log.error(msg) - self.log.error( - "Multiple sources within specified tolerance " - "matched to a single reference source. Try to " - "adjust 'tolerance' and/or 'separation' parameters." - "Alignment to an absolute reference catalog will " - "not be performed." - ) - if local_align_failed or n_groups == 1: - self.log.warning("Skipping 'TweakRegStep'...") - self.skip = True - for model in images: - model.meta.cal_step.tweakreg = "SKIPPED" - return images - else: - raise e - - # one final pass through all the models to update them based - # on the results of this step for (image_model, corrector) in zip(images, correctors): - image_model.meta.cal_step.tweakreg = 'COMPLETE' # retrieve fit status and update wcs if fit is successful: - if ('fit_info' in corrector.meta and - 'SUCCESS' in corrector.meta['fit_info']['status']): + if ("fit_info" in corrector.meta and + "SUCCESS" in corrector.meta["fit_info"]["status"]): # Update/create the WCS .name attribute with information # on this astrometric fit as the only record that it was @@ -509,7 +326,7 @@ def process(self, input): # translated to the FITS WCSNAME keyword # IF that is what gets recorded in the archive # for end-user searches. - corrector.wcs.name = "FIT-LVL3-{}".format(self.abs_refcat) + corrector.wcs.name = f"FIT-LVL3-{self.abs_refcat}" image_model.meta.wcs = corrector.wcs update_s_region_imaging(image_model) @@ -528,12 +345,13 @@ def process(self, input): crpix=None ) except (ValueError, RuntimeError) as e: - self.log.warning( - "Failed to update 'meta.wcsinfo' with FITS SIP " - f'approximation. Reported error is:\n"{e.args[0]}"' - ) + msg = f"Failed to update 'meta.wcsinfo' with FITS SIP \ + approximation. Reported error is: \n {e.args[0]}" + self.log.warning(msg) + image_model.meta.cal_step.tweakreg = "COMPLETE" + + return image_model - return images def _write_catalog(self, catalog, filename): ''' @@ -701,23 +519,3 @@ def _filter_catalog_by_bounding_box(catalog, bounding_box): def _wcs_to_skycoord(wcs): ra, dec = wcs.footprint(axis_type="spatial").T return SkyCoord(ra=ra, dec=dec, unit="deg") - - -def _construct_wcs_corrector(image_model, catalog): - # pre-compute skycoord here so we can later use it - # to check for a small wcs correction - wcs = image_model.meta.wcs - refang = image_model.meta.wcsinfo.instance - return JWSTWCSCorrector( - wcs=image_model.meta.wcs, - wcsinfo={'roll_ref': refang['roll_ref'], - 'v2_ref': refang['v2_ref'], - 'v3_ref': refang['v3_ref']}, - # catalog and group_id are required meta - meta={ - 'catalog': catalog, - 'name': catalog.meta.get('name'), - 'group_id': image_model.meta.group_id, - 'original_skycoord': _wcs_to_skycoord(wcs), - } - ) diff --git a/jwst/tweakreg/utils.py b/jwst/tweakreg/utils.py index 9b3fbaf6e4e..6f4c97f8c3c 100644 --- a/jwst/tweakreg/utils.py +++ b/jwst/tweakreg/utils.py @@ -1,6 +1,5 @@ from copy import deepcopy -from astropy.modeling.rotations import RotationSequence3D from astropy import units from gwcs.wcs import WCS import numpy as np @@ -8,8 +7,9 @@ from tweakwcs.linearfit import build_fit_matrix from stdatamodels.jwst.datamodels import ImageModel -from ..assign_wcs.util import update_fits_wcsinfo +from stcal.tweakreg.utils import _wcsinfo_from_wcs_transform from ..assign_wcs.pointing import _v23tosky +from ..assign_wcs.util import update_fits_wcsinfo _RAD2ARCSEC = 3600.0 * np.rad2deg(1.0) @@ -18,45 +18,6 @@ __all__ = ["adjust_wcs", "transfer_wcs_correction"] -def _wcsinfo_from_wcs_transform(wcs): - frames = wcs.available_frames - if 'v2v3' not in frames or 'world' not in frames or frames[-1] != 'world': - raise ValueError( - "Unsupported WCS structure." - ) - - # Initially get v2_ref, v3_ref, and roll_ref from - # the v2v3 to world transform. Also get ra_ref, dec_ref - t = wcs.get_transform(frames[-2], 'world') - for m in t: - if isinstance(m, RotationSequence3D) and m.parameters.size == 5: - v2_ref, nv3_ref, roll_ref, dec_ref, nra_ref = m.angles.value - break - else: - raise ValueError( - "Unsupported WCS structure." - ) - - # overwrite v2_ref, v3_ref, and roll_ref with - # values from the tangent plane when available: - if 'v2v3corr' in frames: - # get v2_ref, v3_ref, and roll_ref from - # the v2v3 to v2v3corr transform: - frm1 = 'v2v3vacorr' if 'v2v3vacorr' in frames else 'v2v3' - tpcorr = wcs.get_transform(frm1, 'v2v3corr') - v2_ref, nv3_ref, roll_ref = tpcorr['det_to_optic_axis'].angles.value - - wcsinfo = { - 'v2_ref': 3600 * v2_ref, - 'v3_ref': -3600 * nv3_ref, - 'roll_ref': roll_ref, - 'ra_ref': -nra_ref, - 'dec_ref': dec_ref - } - - return wcsinfo - - def adjust_wcs(wcs, delta_ra=0.0, delta_dec=0.0, delta_roll=0.0, scale_factor=1.0): """ diff --git a/pyproject.toml b/pyproject.toml index 11dcffc0772..aed53f0e1ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ dependencies = [ "scikit-image>=0.19", "scipy>=1.9.3", "spherical-geometry>=1.2.22", - "stcal>=1.7.1,<1.8.0", + "git+https://github.com/emolter/stcal.git@AL-850", "stdatamodels>=2.0.0,<2.1.0", "stpipe>=0.6.0,<0.7.0", "stsci.image>=2.3.5",