diff --git a/CHANGES.rst b/CHANGES.rst index a7633462e0..ac9ea231e8 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -49,6 +49,9 @@ extract_1d - Added a hook to bypass the ``extract_1d`` step for NIRISS SOSS data in the FULL subarray with warning. [#8225] +- Fixed a bug in the ATOCA matrix solve for NIRISS SOSS that would cause failures on + good input data in some cases. [#8273] + - Added a trap in the NIRISS SOSS ATOCA algorithm for cases where nearly all pixels in the 2nd-order spectrum are flagged and would cause the step to fail. [#8265] diff --git a/jwst/extract_1d/soss_extract/atoca.py b/jwst/extract_1d/soss_extract/atoca.py index 1b8578a269..b53c668b25 100644 --- a/jwst/extract_1d/soss_extract/atoca.py +++ b/jwst/extract_1d/soss_extract/atoca.py @@ -12,8 +12,9 @@ # General imports. import numpy as np +import warnings from scipy.sparse import issparse, csr_matrix, diags -from scipy.sparse.linalg import spsolve +from scipy.sparse.linalg import spsolve, lsqr, MatrixRankWarning from scipy.interpolate import interp1d # Local imports. @@ -1290,7 +1291,15 @@ def _solve(matrix, result): # Only solve for valid indices, i.e. wavelengths that are # covered by the pixels on the detector. # It will be a singular matrix otherwise. - sln[idx] = spsolve(matrix[idx, :][:, idx], result[idx]) + warnings.filterwarnings(action='error', category=MatrixRankWarning) + try: + sln[idx] = spsolve(matrix[idx, :][:, idx], result[idx]) + except MatrixRankWarning: + # on rare occasions spsolve's approximation of the matrix is not appropriate + # and fails on good input data. revert to different solver + log.info('ATOCA matrix solve failed with spsolve. Retrying with least-squares.') + sln[idx] = lsqr(matrix[idx, :][:, idx], result[idx])[0] + warnings.resetwarnings() return sln diff --git a/jwst/regtest/test_niriss_soss.py b/jwst/regtest/test_niriss_soss.py index b900b0318f..efb2253194 100644 --- a/jwst/regtest/test_niriss_soss.py +++ b/jwst/regtest/test_niriss_soss.py @@ -126,7 +126,24 @@ def test_niriss_soss_extras(rtdata_module, run_atoca_extras, fitsdiff_default_kw assert diff.identical, diff.report() -@pytest.mark.bigdata +@pytest.fixture(scope='module') +def run_extract1d_spsolve_failure(jail, rtdata_module): + """ + Test coverage for fix to error thrown when spsolve fails to find + a good solution in ATOCA and needs to be replaced with a least- + squares solver. Note this failure is architecture-dependent + and also only trips for specific values of the transform parameters. + Pin tikfac for faster runtime. + """ + rtdata = rtdata_module + rtdata.get_data("niriss/soss/jw04098007001_04101_00001-seg003_nis_int01.fits") + args = ["extract_1d", rtdata.input, + "--soss_tikfac=3.1881637371089252e-15", + "--soss_transform=-0.00038201755227297866, -0.24237455427848956, 0.5404013401742825", + ] + Step.from_cmdline(args) + + @pytest.fixture(scope='module') def run_extract1d_null_order2(jail, rtdata_module): """ @@ -144,8 +161,23 @@ def run_extract1d_null_order2(jail, rtdata_module): Step.from_cmdline(args) +@pytest.mark.bigdata +def test_extract1d_spsolve_failure(rtdata_module, run_extract1d_spsolve_failure, fitsdiff_default_kwargs): + + rtdata = rtdata_module + + output = "jw04098007001_04101_00001-seg003_nis_int01_extract1dstep.fits" + rtdata.output = output + + rtdata.get_truth(f"truth/test_niriss_soss_stages/{output}") + + diff = FITSDiff(rtdata.output, rtdata.truth, **fitsdiff_default_kwargs) + assert diff.identical, diff.report() + + @pytest.mark.bigdata def test_extract1d_null_order2(rtdata_module, run_extract1d_null_order2, fitsdiff_default_kwargs): + rtdata = rtdata_module output = "jw01201008001_04101_00001-seg003_nis_int72_extract1dstep.fits"