Skip to content

Commit

Permalink
Add tests and update validation warning reporting for all classes
Browse files Browse the repository at this point in the history
  • Loading branch information
jdavies-st committed Dec 20, 2020
1 parent 795d1ef commit a3c8100
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 65 deletions.
63 changes: 58 additions & 5 deletions jwst/assign_wcs/tests/test_schemas.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,29 @@
import inspect
import sys

from astropy.modeling import models
from astropy import units as u
import pytest

from jwst.datamodels import DistortionModel
from jwst.datamodels import DistortionModel, ReferenceFileModel
from jwst.datamodels import wcs_ref_models
from jwst.datamodels.wcs_ref_models import _SimpleModel


def test_distortion_schema(tmpdir):
"""Make sure DistortionModel roundtrips"""
def find_all_wcs_ref_models_classes():
clsmembers = inspect.getmembers(sys.modules[wcs_ref_models.__name__], inspect.isclass)
classes = [cls for name,cls in clsmembers if issubclass(cls, ReferenceFileModel)]
classes.remove(_SimpleModel)
return classes


@pytest.fixture
def distortion_model():
"""Create a distortion model that should pass all validation"""
m = models.Shift(1) & models.Shift(2)
dist = DistortionModel(model=m, input_units=u.pixel, output_units=u.arcsec)

dist.meta.reftype = "distortion"
dist.meta.instrument.name = "NIRCAM"
dist.meta.instrument.detector = "NRCA1"
dist.meta.instrument.p_pupil = "F162M|F164N|CLEAR|"
Expand All @@ -29,7 +43,13 @@ def test_distortion_schema(tmpdir):
dist.meta.pedigree = "Cleveland"
dist.meta.useafter = "2000-01-01T00:00:00"

return dist


def test_distortion_schema(distortion_model, tmpdir):
"""Make sure DistortionModel roundtrips"""
path = str(tmpdir.join("test_dist.asdf"))
dist = distortion_model
dist.save(path)

with pytest.warns(None) as report:
Expand All @@ -42,5 +62,38 @@ def test_distortion_schema(tmpdir):
assert dist1.meta.subarray.name == dist.meta.subarray.name
assert len(report) == 0

with DistortionModel(path, strict_validation=True) as dist2:
dist2.validate()

def test_distortion_strict_validation(distortion_model):
"""Make sure strict validation works"""
distortion_model.validate()


def test_distortion_schema_bad_valueerror(distortion_model):
"""Check that ValueError is raised for ReferenceFile missing items"""
dist = DistortionModel(distortion_model, strict_validation=True)
dist.meta.author = None

with pytest.raises(ValueError):
dist.validate()


def test_distortion_schema_bad_assertionerror(distortion_model):
"""Check that AssertionError is raised for distortion-specific missing items"""
dist = DistortionModel(distortion_model, strict_validation=True)
dist.meta.instrument.channel = None

with pytest.raises(AssertionError):
dist.validate()


@pytest.mark.parametrize("cls", find_all_wcs_ref_models_classes())
def test_simplemodel_subclasses(cls):
"""Test that expected validation errors are raised"""
model = cls()
with pytest.warns(None) as report:
model.validate()
assert len(report) >= 1

model = cls(strict_validation=True)
with pytest.raises((ValueError, KeyError)):
model.validate()
Loading

0 comments on commit a3c8100

Please sign in to comment.