Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JP-3483: fix atoca failures due to sparse matrix solver #8273

Merged
merged 8 commits into from
Feb 16, 2024
Merged
3 changes: 3 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ extract_1d
- Added a hook to bypass the ``extract_1d`` step for NIRISS SOSS data in
the FULL subarray with warning. [#8225]

- Fixed a bug in the ATOCA matrix solve for NIRISS SOSS that would cause failures on
good input data in some cases. [#8273]

- Added a trap in the NIRISS SOSS ATOCA algorithm for cases where nearly all
pixels in the 2nd-order spectrum are flagged and would cause the step
to fail. [#8265]
Expand Down
13 changes: 11 additions & 2 deletions jwst/extract_1d/soss_extract/atoca.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@

# General imports.
import numpy as np
import warnings
from scipy.sparse import issparse, csr_matrix, diags
from scipy.sparse.linalg import spsolve
from scipy.sparse.linalg import spsolve, lsqr, MatrixRankWarning
from scipy.interpolate import interp1d

# Local imports.
Expand Down Expand Up @@ -1290,7 +1291,15 @@
# Only solve for valid indices, i.e. wavelengths that are
# covered by the pixels on the detector.
# It will be a singular matrix otherwise.
sln[idx] = spsolve(matrix[idx, :][:, idx], result[idx])
warnings.filterwarnings(action='error', category=MatrixRankWarning)
try:
sln[idx] = spsolve(matrix[idx, :][:, idx], result[idx])
except MatrixRankWarning:

Check warning on line 1297 in jwst/extract_1d/soss_extract/atoca.py

View check run for this annotation

Codecov / codecov/patch

jwst/extract_1d/soss_extract/atoca.py#L1294-L1297

Added lines #L1294 - L1297 were not covered by tests
# on rare occasions spsolve's approximation of the matrix is not appropriate
# and fails on good input data. revert to different solver
log.info('ATOCA matrix solve failed with spsolve. Retrying with least-squares.')
sln[idx] = lsqr(matrix[idx, :][:, idx], result[idx])[0]
warnings.resetwarnings()

Check warning on line 1302 in jwst/extract_1d/soss_extract/atoca.py

View check run for this annotation

Codecov / codecov/patch

jwst/extract_1d/soss_extract/atoca.py#L1300-L1302

Added lines #L1300 - L1302 were not covered by tests

return sln

Expand Down
34 changes: 33 additions & 1 deletion jwst/regtest/test_niriss_soss.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,24 @@ def test_niriss_soss_extras(rtdata_module, run_atoca_extras, fitsdiff_default_kw
assert diff.identical, diff.report()


@pytest.mark.bigdata
@pytest.fixture(scope='module')
def run_extract1d_spsolve_failure(jail, rtdata_module):
"""
Test coverage for fix to error thrown when spsolve fails to find
a good solution in ATOCA and needs to be replaced with a least-
squares solver. Note this failure is architecture-dependent
and also only trips for specific values of the transform parameters.
Pin tikfac for faster runtime.
"""
rtdata = rtdata_module
rtdata.get_data("niriss/soss/jw04098007001_04101_00001-seg003_nis_int01.fits")
args = ["extract_1d", rtdata.input,
"--soss_tikfac=3.1881637371089252e-15",
"--soss_transform=-0.00038201755227297866, -0.24237455427848956, 0.5404013401742825",
]
Step.from_cmdline(args)


@pytest.fixture(scope='module')
def run_extract1d_null_order2(jail, rtdata_module):
"""
Expand All @@ -144,8 +161,23 @@ def run_extract1d_null_order2(jail, rtdata_module):
Step.from_cmdline(args)


@pytest.mark.bigdata
def test_extract1d_spsolve_failure(rtdata_module, run_extract1d_spsolve_failure, fitsdiff_default_kwargs):

rtdata = rtdata_module

output = "jw04098007001_04101_00001-seg003_nis_int01_extract1dstep.fits"
rtdata.output = output

rtdata.get_truth(f"truth/test_niriss_soss_stages/{output}")

diff = FITSDiff(rtdata.output, rtdata.truth, **fitsdiff_default_kwargs)
assert diff.identical, diff.report()


@pytest.mark.bigdata
def test_extract1d_null_order2(rtdata_module, run_extract1d_null_order2, fitsdiff_default_kwargs):

rtdata = rtdata_module

output = "jw01201008001_04101_00001-seg003_nis_int72_extract1dstep.fits"
Expand Down
Loading