From f34f5187e7bb7a9d49ae612f20a324cdc8ca19ac Mon Sep 17 00:00:00 2001 From: James Davies Date: Mon, 4 Jan 2021 11:33:48 -0500 Subject: [PATCH] Fix distortion reffile schema and unit test (#5553) * Fix distortion reffile schema and unit test * Add tests and update validation warning reporting for all classes --- CHANGES.rst | 3 + jwst/assign_wcs/tests/test_schemas.py | 91 ++++++++++-- .../datamodels/schemas/distortion.schema.yaml | 1 + jwst/datamodels/wcs_ref_models.py | 132 +++++++++--------- 4 files changed, 153 insertions(+), 74 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index a6a819ae32..d767947db9 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -99,6 +99,9 @@ datamodels - Update Moving Target CHEBY table extension schema for changes to column definitions in the JWSTKD and SDP [#5558] +- Update distortion reference file schema to have ``meta.instrument.channel`` + keyword [#5553] + extract_1d ---------- diff --git a/jwst/assign_wcs/tests/test_schemas.py b/jwst/assign_wcs/tests/test_schemas.py index d30b0870ff..72914fa191 100644 --- a/jwst/assign_wcs/tests/test_schemas.py +++ b/jwst/assign_wcs/tests/test_schemas.py @@ -1,13 +1,29 @@ +import inspect +import sys + from astropy.modeling import models from astropy import units as u -from jwst.datamodels import DistortionModel +import pytest +from jwst.datamodels import DistortionModel, ReferenceFileModel +from jwst.datamodels import wcs_ref_models +from jwst.datamodels.wcs_ref_models import _SimpleModel -def test_distortion_schema(tmpdir): - """Make sure DistortionModel roundtrips""" + +def find_all_wcs_ref_models_classes(): + clsmembers = inspect.getmembers(sys.modules[wcs_ref_models.__name__], inspect.isclass) + classes = [cls for name,cls in clsmembers if issubclass(cls, ReferenceFileModel)] + classes.remove(_SimpleModel) + return classes + + +@pytest.fixture +def distortion_model(): + """Create a distortion model that should pass all validation""" m = models.Shift(1) & models.Shift(2) dist = DistortionModel(model=m, input_units=u.pixel, output_units=u.arcsec) + dist.meta.reftype = "distortion" dist.meta.instrument.name = "NIRCAM" dist.meta.instrument.detector = "NRCA1" dist.meta.instrument.p_pupil = "F162M|F164N|CLEAR|" @@ -16,13 +32,68 @@ def test_distortion_schema(tmpdir): dist.meta.exposure.type = "NRC_IMAGE" dist.meta.psubarray = "FULL|SUB64P|SUB160)|SUB160P|SUB320|SUB400P|SUB640|" dist.meta.subarray.name = "FULL" + + # Populate the following so that no validation warnings or errors happen + dist.meta.instrument.module = "A" + dist.meta.instrument.channel = "SHORT" + dist.meta.input_units = u.degree + dist.meta.output_units = u.degree + dist.meta.description = "NIRCam distortion reference file" + dist.meta.author = "Hank the Septopus" + dist.meta.pedigree = "Cleveland" + dist.meta.useafter = "2000-01-01T00:00:00" + + return dist + + +def test_distortion_schema(distortion_model, tmpdir): + """Make sure DistortionModel roundtrips""" path = str(tmpdir.join("test_dist.asdf")) + dist = distortion_model dist.save(path) - with DistortionModel(path) as dist1: - assert dist1.meta.instrument.p_pupil == dist.meta.instrument.p_pupil - assert dist1.meta.instrument.pupil == dist.meta.instrument.pupil - assert dist1.meta.exposure.p_exptype == dist.meta.exposure.p_exptype - assert dist1.meta.exposure.type == dist.meta.exposure.type - assert dist1.meta.psubarray == dist.meta.psubarray - assert dist1.meta.subarray.name == dist.meta.subarray.name + with pytest.warns(None) as report: + with DistortionModel(path) as dist1: + assert dist1.meta.instrument.p_pupil == dist.meta.instrument.p_pupil + assert dist1.meta.instrument.pupil == dist.meta.instrument.pupil + assert dist1.meta.exposure.p_exptype == dist.meta.exposure.p_exptype + assert dist1.meta.exposure.type == dist.meta.exposure.type + assert dist1.meta.psubarray == dist.meta.psubarray + assert dist1.meta.subarray.name == dist.meta.subarray.name + assert len(report) == 0 + + +def test_distortion_strict_validation(distortion_model): + """Make sure strict validation works""" + distortion_model.validate() + + +def test_distortion_schema_bad_valueerror(distortion_model): + """Check that ValueError is raised for ReferenceFile missing items""" + dist = DistortionModel(distortion_model, strict_validation=True) + dist.meta.author = None + + with pytest.raises(ValueError): + dist.validate() + + +def test_distortion_schema_bad_assertionerror(distortion_model): + """Check that AssertionError is raised for distortion-specific missing items""" + dist = DistortionModel(distortion_model, strict_validation=True) + dist.meta.instrument.channel = None + + with pytest.raises(AssertionError): + dist.validate() + + +@pytest.mark.parametrize("cls", find_all_wcs_ref_models_classes()) +def test_simplemodel_subclasses(cls): + """Test that expected validation errors are raised""" + model = cls() + with pytest.warns(None) as report: + model.validate() + assert len(report) >= 1 + + model = cls(strict_validation=True) + with pytest.raises((ValueError, KeyError)): + model.validate() diff --git a/jwst/datamodels/schemas/distortion.schema.yaml b/jwst/datamodels/schemas/distortion.schema.yaml index 53e5059ad2..27bf45aab4 100644 --- a/jwst/datamodels/schemas/distortion.schema.yaml +++ b/jwst/datamodels/schemas/distortion.schema.yaml @@ -10,6 +10,7 @@ allOf: - $ref: keyword_pupil.schema - $ref: keyword_ppupil.schema - $ref: keyword_module.schema +- $ref: keyword_channel.schema - $ref: keyword_exptype.schema - $ref: keyword_pexptype.schema - $ref: subarray.schema diff --git a/jwst/datamodels/wcs_ref_models.py b/jwst/datamodels/wcs_ref_models.py index 6412baf1f8..4378799cf7 100644 --- a/jwst/datamodels/wcs_ref_models.py +++ b/jwst/datamodels/wcs_ref_models.py @@ -1,5 +1,7 @@ -import numpy as np +import traceback import warnings + +import numpy as np from astropy.modeling.core import Model from astropy import units as u from stdatamodels.validate import ValidationWarning @@ -48,15 +50,15 @@ def to_fits(self): raise NotImplementedError("FITS format is not supported for this file.") def validate(self): - super(_SimpleModel, self).validate() + super().validate() try: assert isinstance(self.model, Model) assert self.meta.instrument.name in ["NIRCAM", "NIRSPEC", "MIRI", "TFI", "FGS", "NIRISS"] - except AssertionError as errmsg: + except AssertionError: if self._strict_validation: - raise AssertionError(errmsg) + raise else: - warnings.warn(str(errmsg), ValidationWarning) + warnings.warn(traceback.format_exc(), ValidationWarning) class DistortionModel(_SimpleModel): """ @@ -66,7 +68,7 @@ class DistortionModel(_SimpleModel): reftype = "distortion" def validate(self): - super(DistortionModel, self).validate() + super().validate() try: assert isinstance(self.meta.input_units, (str, u.NamedUnit)) assert isinstance(self.meta.output_units, (str, u.NamedUnit)) @@ -74,11 +76,11 @@ def validate(self): assert self.meta.instrument.module is not None assert self.meta.instrument.channel is not None assert self.meta.instrument.p_pupil is not None - except AssertionError as errmsg: + except AssertionError: if self._strict_validation: - raise AssertionError(errmsg) + raise else: - warnings.warn(str(errmsg), ValidationWarning) + warnings.warn(traceback.format_exc(), ValidationWarning) class DistortionMRSModel(ReferenceFileModel): """ @@ -90,7 +92,7 @@ class DistortionMRSModel(ReferenceFileModel): def __init__(self, init=None, x_model=None, y_model=None, alpha_model=None, beta_model=None, bzero=None, bdel=None, input_units=None, output_units=None, **kwargs): - super(DistortionMRSModel, self).__init__(init=init, **kwargs) + super().__init__(init=init, **kwargs) if x_model is not None: self.x_model = x_model @@ -127,7 +129,7 @@ def to_fits(self): raise NotImplementedError("FITS format is not supported for this file.") def validate(self): - super(DistortionMRSModel, self).validate() + super().validate() try: assert isinstance(self.meta.input_units, (str, u.NamedUnit)) assert isinstance(self.meta.output_units, (str, u.NamedUnit)) @@ -142,11 +144,12 @@ def validate(self): assert all([isinstance(m, Model) for m in self.beta_model]) assert len(self.abv2v3_model.model) == 2 assert len(self.abv2v3_model.channel_band) == 2 - except AssertionError as errmsg: + except AssertionError: if self._strict_validation: - raise AssertionError(errmsg) + raise else: - warnings.warn(str(errmsg), ValidationWarning) + warnings.warn(traceback.format_exc(), ValidationWarning) + class SpecwcsModel(_SimpleModel): """ @@ -162,7 +165,7 @@ class SpecwcsModel(_SimpleModel): reftype = "specwcs" def validate(self): - super(SpecwcsModel, self).validate() + super().validate() try: assert isinstance(self.meta.input_units, (str, u.NamedUnit)) assert isinstance(self.meta.output_units, (str, u.NamedUnit)) @@ -172,11 +175,12 @@ def validate(self): "TFI", "FGS", "NIRISS"] - except AssertionError as errmsg: + except AssertionError: if self._strict_validation: - raise AssertionError(errmsg) + raise else: - warnings.warn(str(errmsg), ValidationWarning) + warnings.warn(traceback.format_exc(), ValidationWarning) + class NIRCAMGrismModel(ReferenceFileModel): """ @@ -216,7 +220,7 @@ def __init__(self, init=None, invdispy=None, orders=None, **kwargs): - super(NIRCAMGrismModel, self).__init__(init=init, **kwargs) + super().__init__(init=init, **kwargs) if init is None: self.populate_meta() @@ -241,18 +245,18 @@ def populate_meta(self): self.meta.reftype = self.reftype def validate(self): - super(NIRCAMGrismModel, self).validate() + super().validate() try: assert isinstance(self.meta.input_units, (str, u.NamedUnit)) assert isinstance(self.meta.output_units, (str, u.NamedUnit)) assert self.meta.instrument.name == "NIRCAM" assert self.meta.exposure.type == "NRC_WFSS" assert self.meta.reftype == self.reftype - except AssertionError as errmsg: + except AssertionError: if self._strict_validation: - raise AssertionError(errmsg) + raise else: - warnings.warn(str(errmsg), ValidationWarning) + warnings.warn(traceback.format_exc(), ValidationWarning) def to_fits(self): raise NotImplementedError("FITS format is not supported for this file.") @@ -293,7 +297,7 @@ def __init__(self, init=None, orders=None, fwcpos_ref=None, **kwargs): - super(NIRISSGrismModel, self).__init__(init=init, **kwargs) + super().__init__(init=init, **kwargs) if init is None: self.populate_meta() @@ -325,11 +329,11 @@ def validate(self): assert self.meta.exposure.type == "NIS_WFSS" assert self.meta.instrument.detector == "NIS" assert self.meta.reftype == self.reftype - except AssertionError as errmsg: + except AssertionError: if self._strict_validation: - raise AssertionError(errmsg) + raise else: - warnings.warn(str(errmsg), ValidationWarning) + warnings.warn(traceback.format_exc(), ValidationWarning) def to_fits(self): raise NotImplementedError("FITS format is not supported for this file.") @@ -343,7 +347,7 @@ class RegionsModel(ReferenceFileModel): reftype = "regions" def __init__(self, init=None, regions=None, **kwargs): - super(RegionsModel, self).__init__(init=init, **kwargs) + super().__init__(init=init, **kwargs) if regions is not None: self.regions = regions if init is None: @@ -360,7 +364,7 @@ def to_fits(self): raise NotImplementedError("FITS format is not supported for this file.") def validate(self): - super(RegionsModel, self).validate() + super().validate() try: assert isinstance(self.regions, np.ndarray) assert self.meta.instrument.name == "MIRI" @@ -368,11 +372,11 @@ def validate(self): assert self.meta.instrument.channel in ("12", "34", "1", "2", "3", "4") assert self.meta.instrument.band in ("SHORT", "LONG") assert self.meta.instrument.detector in ("MIRIFUSHORT", "MIRIFULONG") - except AssertionError as errmsg: + except AssertionError: if self._strict_validation: - raise AssertionError(errmsg) + raise else: - warnings.warn(str(errmsg), ValidationWarning) + warnings.warn(traceback.format_exc(), ValidationWarning) class WavelengthrangeModel(ReferenceFileModel): @@ -400,7 +404,7 @@ class WavelengthrangeModel(ReferenceFileModel): def __init__(self, init=None, wrange_selector=None, wrange=None, order=None, extract_orders=None, wunits=None, **kwargs): - super(WavelengthrangeModel, self).__init__(init=init, **kwargs) + super().__init__(init=init, **kwargs) if wrange_selector is not None: self.waverange_selector = wrange_selector if wrange is not None: @@ -419,14 +423,14 @@ def to_fits(self): raise NotImplementedError("FITS format is not supported for this file") def validate(self): - super(WavelengthrangeModel, self).validate() + super().validate() try: assert self.meta.instrument.name in ("MIRI", "NIRSPEC", "NIRCAM", "NIRISS") - except AssertionError as errmsg: + except AssertionError: if self._strict_validation: - raise AssertionError(errmsg) + raise else: - warnings.warn(str(errmsg), ValidationWarning) + warnings.warn(traceback.format_exc(), ValidationWarning) def get_wfss_wavelength_range(self, filter, orders): """ Retrieve the wavelength range for a WFSS observation. @@ -459,7 +463,7 @@ class FPAModel(ReferenceFileModel): def __init__(self, init=None, nrs1_model=None, nrs2_model=None, **kwargs): - super(FPAModel, self).__init__(init=init, **kwargs) + super().__init__(init=init, **kwargs) if nrs1_model is not None: self.nrs1_model = nrs1_model if nrs2_model is not None: @@ -482,15 +486,15 @@ def to_fits(self): raise NotImplementedError("FITS format is not supported for this file.") def validate(self): - super(FPAModel, self).validate() + super().validate() try: assert isinstance(self.nrs1_model, Model) assert isinstance(self.nrs2_model, Model) - except AssertionError as errmsg: + except AssertionError: if self._strict_validation: - raise AssertionError(errmsg) + raise else: - warnings.warn(str(errmsg), ValidationWarning) + warnings.warn(traceback.format_exc(), ValidationWarning) class IFUPostModel(ReferenceFileModel): @@ -517,7 +521,7 @@ class IFUPostModel(ReferenceFileModel): def __init__(self, init=None, slice_models=None, **kwargs): - super(IFUPostModel, self).__init__(init=init, **kwargs) + super().__init__(init=init, **kwargs) if slice_models is not None: if len(slice_models) != 30: raise ValueError("Expected 30 slice models, got {0}".format(len(slice_models))) @@ -540,7 +544,7 @@ def to_fits(self): raise NotImplementedError("FITS format is not supported for this file.") def validate(self): - super(IFUPostModel, self).validate() + super().validate() class IFUSlicerModel(ReferenceFileModel): @@ -552,7 +556,7 @@ class IFUSlicerModel(ReferenceFileModel): def __init__(self, init=None, model=None, data=None, **kwargs): - super(IFUSlicerModel, self).__init__(init=init, **kwargs) + super().__init__(init=init, **kwargs) if model is not None: self.model = model if data is not None: @@ -573,7 +577,7 @@ def to_fits(self): raise NotImplementedError("FITS format is not supported for this file.") def validate(self): - super(IFUSlicerModel, self).validate() + super().validate() class MSAModel(ReferenceFileModel): @@ -584,7 +588,7 @@ class MSAModel(ReferenceFileModel): reftype = "msa" def __init__(self, init=None, models=None, data=None, **kwargs): - super(MSAModel, self).__init__(init=init, **kwargs) + super().__init__(init=init, **kwargs) if models is not None and data is not None: self.Q1 = {'model': models['Q1'], 'data': data['Q1']} self.Q2 = {'model': models['Q2'], 'data': data['Q2']} @@ -609,7 +613,7 @@ def to_fits(self): raise NotImplementedError("FITS format is not supported for this file.") def validate(self): - super(MSAModel, self).validate() + super().validate() class DisperserModel(ReferenceFileModel): @@ -623,7 +627,7 @@ def __init__(self, init=None, angle=None, gwa_tiltx=None, gwa_tilty=None, kcoef=None, lcoef=None, tcoef=None, pref=None, tref=None, theta_x=None, theta_y=None,theta_z=None, groovedensity=None, **kwargs): - super(DisperserModel, self).__init__(init=init, **kwargs) + super().__init__(init=init, **kwargs) if groovedensity is not None: self.groovedensity = groovedensity if angle is not None: @@ -670,15 +674,15 @@ def to_fits(self): raise NotImplementedError("FITS format is not supported for this file.") def validate(self): - super(DisperserModel, self).validate() + super().validate() try: assert self.meta.instrument.grating in ["G140H", "G140M", "G235H", "G235M", "G395H", "G395M", "MIRROR", "PRISM"] - except AssertionError as errmsg: + except AssertionError: if self._strict_validation: - raise AssertionError(errmsg) + raise else: - warnings.warn(str(errmsg), ValidationWarning) + warnings.warn(traceback.format_exc(), ValidationWarning) class FilteroffsetModel(ReferenceFileModel): @@ -689,7 +693,7 @@ class FilteroffsetModel(ReferenceFileModel): reftype = "filteroffset" def __init__(self, init=None, filters=None, instrument=None, **kwargs): - super(FilteroffsetModel, self).__init__(init, **kwargs) + super().__init__(init, **kwargs) if filters is not None: self.filters = filters if instrument is None or instrument not in ("NIRCAM", "MIRI", "NIRISS"): @@ -712,7 +716,7 @@ def populate_meta(self): raise ValueError(f"Unsupported instrument: {self.meta.instrument.name}") def validate(self): - super(FilteroffsetModel, self).validate() + super().validate() instrument_name = self.meta.instrument.name nircam_channels = ["SHORT", "LONG"] @@ -809,15 +813,15 @@ def on_save(self, path=None): self.meta.reftype = self.reftype def validate(self): - super(FOREModel, self).validate() + super().validate() try: assert self.meta.instrument.filter in ["CLEAR", "F070LP", "F100LP", "F110W", "F140X", "F170LP", "F290LP"] - except AssertionError as errmsg: + except AssertionError: if self._strict_validation: - raise AssertionError(errmsg) + raise else: - warnings.warn(str(errmsg), ValidationWarning) + warnings.warn(traceback.format_exc(), ValidationWarning) class WaveCorrModel(ReferenceFileModel): @@ -826,7 +830,7 @@ class WaveCorrModel(ReferenceFileModel): schema_url = "http://stsci.edu/schemas/jwst_datamodel/wavecorr.schema" def __init__(self, init=None, apertures=None, **kwargs): - super(WaveCorrModel, self).__init__(init, **kwargs) + super().__init__(init, **kwargs) if apertures is not None: self.apertures = apertures if init is None: @@ -843,12 +847,12 @@ def on_save(self, path=None): self.meta.reftype = self.reftype def validate(self): - super(WaveCorrModel, self).validate() + super().validate() try: assert self.aperture_names is not None assert self.apertures is not None - except AssertionError as errmsg: + except AssertionError: if self._strict_validation: - raise AssertionError(errmsg) + raise else: - warnings.warn(str(errmsg), ValidationWarning) + warnings.warn(traceback.format_exc(), ValidationWarning)