Skip to content

Commit

Permalink
JP-3321: NRS IRS2 refpix bad pixel updates (#7745)
Browse files Browse the repository at this point in the history
Co-authored-by: Howard Bushouse <[email protected]>
  • Loading branch information
penaguerrero and hbushouse authored Aug 11, 2023
1 parent c607b4e commit 9143aec
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 64 deletions.
9 changes: 8 additions & 1 deletion CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ flat_field
----------

- Modify the test_flatfield_step_interface unit test to prevent it from causing
other tests to fail [#7752]
other tests to fail. [#7752]

general
-------
Expand All @@ -51,6 +51,13 @@ pathloss

- Fix interpolation error for point source corrections. [#7799]

refpix
------

- Modified algorithm of intermittent bad pixels factor to be the number
of sigmas away from mean for the corresponding array (either differences,
means, or standard deviations arrays). [#7745]

resample
--------

Expand Down
5 changes: 3 additions & 2 deletions docs/jwst/refpix/arguments.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,5 +41,6 @@ default value is True, and this argument applies to MIR data only.
* ``--ovr_corr_mitigation_ftr``

This is a factor to avoid overcorrection of intermittently bad reference
pixels in the IRS2 algorithm. The default value is 1.8, and this argument
applies only to NIRSpec data taken with IRS2 mode.
pixels in the IRS2 algorithm. This factor is the number of sigmas away
from the mean. The default value is 3.0, and this argument applies
only to NIRSpec data taken with IRS2 mode.
13 changes: 8 additions & 5 deletions docs/jwst/refpix/description.rst
Original file line number Diff line number Diff line change
Expand Up @@ -198,11 +198,14 @@ will be set to zero if they are flagged as bad in the DQ extension.
At this point the algorithm looks for intermittently bad (or suspicious)
reference pixels. This is done by calculating the means and standard
deviations per reference pixel column, as well as the difference between
even and odd pairs; then calculates the mean of each of these arrays (the
mean of the absolute values for the differences array), and flag all
values greater than the corresponding mean times a factor to avoid
overcorrection. All suspicious pixels will be replaced by their
nearest good reference pixel.
even and odd pairs; then calculates the mean and standard deviation of
each of these arrays (the mean of the absolute values for the
differences array), and flag all values greater than the corresponding
mean plus the standard deviation times a factor to avoid overcorrection.
All suspicious pixels will be replaced by their nearest good reference
pixel, or set to zero if there were no good reference pixels left
(although this is unlikely to happen as there are typically only a few
pixels flagged as suspicious).

The next step in this processing is to
copy the science data and the reference pixel data separately to temporary
Expand Down
47 changes: 27 additions & 20 deletions jwst/refpix/irs2_subtract_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

def correct_model(input_model, irs2_model,
scipix_n_default=16, refpix_r_default=4, pad=8,
ovr_corr_mitigation_ftr=1.8):
ovr_corr_mitigation_ftr=3.0):
"""Correct an input NIRSpec IRS2 datamodel using reference pixels.
Parameters
Expand All @@ -34,7 +34,8 @@ def correct_model(input_model, irs2_model,
the phase of temporally periodic signals.
ovr_corr_mitigation_ftr: float
Factor to avoid overcorrection of intermittently bad reference pixels
Factor to avoid overcorrection of intermittently bad reference
pixels. This factor is the N sigmas away from the mean.
Returns
-------
Expand Down Expand Up @@ -146,7 +147,7 @@ def correct_model(input_model, irs2_model,

# Compute and apply the correction to one integration at a time
for integ in range(n_int):
log.info(f'Working on integration {integ+1}')
log.info(f'Working on integration {integ+1} out of {n_int}')

# The input data have a length of 3200 for the last axis (X), while
# the output data have an X axis with length 2048, the same as the
Expand Down Expand Up @@ -343,9 +344,9 @@ def clobber_ref(data, output, odd_even, mask, scipix_n=16, refpix_r=4):
for k in bits:
ref = (offset + scipix_n // 2 + k * (scipix_n + refpix_r) +
2 * (odd_even_row - 1))
log.debug("bad interleaved reference at pixels {} {}"
.format(ref, ref + 1))
data[..., ref:ref + 2] = 0.
log.debug("bad interleaved reference at pixels {} through {}"
.format(ref, ref + 4))
data[..., ref:ref + 4] = 0.


def decode_mask(output, mask):
Expand Down Expand Up @@ -377,14 +378,12 @@ def decode_mask(output, mask):

# The bit number corresponds to a count of groups of reads of the
# interleaved reference pixels. The 32-bit unsigned integer encoding
# has increasing index, following the amplifier readout direction.
# has increasing index, from left to right.

flags = np.array([2**n for n in range(32)], dtype=np.uint32)
temp = np.bitwise_and(flags, mask)
bits = np.where(temp > 0)[0]
bits = list(bits)
if output // 2 * 2 == output:
bits = [31 - bit for bit in bits]
bits.sort()

return bits
Expand All @@ -410,7 +409,8 @@ def rm_intermittent_badpix(data, scipix_n, refpix_r, ovr_corr_mitigation_ftr):
regular samples.
ovr_corr_mitigation_ftr: float
Factor to avoid overcorrection of bad intermittent reference pixels
Factor to avoid overcorrection of intermittently bad reference
pixels. This factor is the N sigmas away from the mean.
Returns
-------
Expand All @@ -419,6 +419,9 @@ def rm_intermittent_badpix(data, scipix_n, refpix_r, ovr_corr_mitigation_ftr):
science and interleaved reference pixel values. The intermittently
bad pixels are now set to the nearest good reference pixel value.
"""

log.info('Using overcorrection mitigation factor = {}'.format(ovr_corr_mitigation_ftr))

# The intermittently bad pixels will be replaced for all integrations
# and all groups. The last group will be used to identify them
nints, ngroups, ny, nx = np.shape(data)
Expand Down Expand Up @@ -458,23 +461,26 @@ def rm_intermittent_badpix(data, scipix_n, refpix_r, ovr_corr_mitigation_ftr):
rp2check.append(odd_pix)
pair = 0
diff_m = np.mean(np.abs(diffs))
std_of_diffs = np.std(diffs)
mean_mean = np.mean(rp_means)
std_of_means = np.std(rp_means)
mean_std = np.mean(rp_stds)
std_of_std = np.std(rp_stds)

# order indexes increasing from left to right
rp2check.sort()

# find the additional intermittent bad pixels - the factor is to avoid overcorrection
log.info('Using overcorrection mitigation factor = {}'.format(ovr_corr_mitigation_ftr))
high_diffs = np.where(np.abs(diffs) > ovr_corr_mitigation_ftr * diff_m)[0]
high_diffs = np.where(np.abs(diffs) > ovr_corr_mitigation_ftr*std_of_diffs + diff_m)[0]
hd_rp2replace = []
for j in high_diffs:
rp2r = rp2check[int(diffs.index(diffs[j]) * 2)]
# include both even and odd
hd_rp2replace.append(rp2r)
hd_rp2replace.append(rp2r+1)
high_means_idx = np.where(np.array(rp_means) > ovr_corr_mitigation_ftr * mean_mean)[0]
high_std_idx = np.where(np.array(rp_stds) > ovr_corr_mitigation_ftr * mean_std)[0]
high_means_idx = np.where(np.array(rp_means) > ovr_corr_mitigation_ftr*std_of_means + mean_mean)[0]
high_std_idx = np.where(np.array(rp_stds) > ovr_corr_mitigation_ftr*std_of_std + mean_std)[0]
log.info('high_diffs={} high_means={} high_stds={}'.format(len(high_diffs), len(high_means_idx), len(high_std_idx)))
ref_pix = np.array(ref_pix)
rp2replace = []
for rp in ref_pix:
Expand All @@ -499,12 +505,13 @@ def rm_intermittent_badpix(data, scipix_n, refpix_r, ovr_corr_mitigation_ftr):
remaining_rp = remaining_rp_even
else:
remaining_rp = remaining_rp_odd
if len(remaining_rp) == 0: # claims all ref pix are bad, avoid overcorrection and skip
continue
good_idx = (np.abs(remaining_rp - bad_pix)).argmin()
good_pix = remaining_rp[good_idx]
data[..., bad_pix] = data[..., good_pix]
log.debug(' Pixel {}'.format(bad_pix))
if len(remaining_rp) == 0: # no good ref pix left, let interpolation do it's magic
data[..., bad_pix] = 0.
else:
good_idx = (np.abs(remaining_rp - bad_pix)).argmin()
good_pix = remaining_rp[good_idx]
data[..., bad_pix] = data[..., good_pix]
log.info(' Pixel {}'.format(bad_pix))
log.info('Total intermittent bad reference pixels: {}'.format(len(total_rp2replace)))


Expand Down
4 changes: 2 additions & 2 deletions jwst/refpix/refpix_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class RefPixStep(Step):
side_smoothing_length = integer(default=11)
side_gain = float(default=1.0)
odd_even_rows = boolean(default=True)
ovr_corr_mitigation_ftr = float(default=1.8)
ovr_corr_mitigation_ftr = float(default=3.0)
"""

reference_file_types = ['refpix']
Expand Down Expand Up @@ -53,7 +53,7 @@ def process(self, input):

# Apply the IRS2 correction scheme
result = irs2_subtract_reference.correct_model(input_model, irs2_model,
self.ovr_corr_mitigation_ftr)
ovr_corr_mitigation_ftr=self.ovr_corr_mitigation_ftr)

if result.meta.cal_step.refpix != 'SKIPPED':
result.meta.cal_step.refpix = 'COMPLETE'
Expand Down
52 changes: 19 additions & 33 deletions jwst/refpix/tests/test_clobber_ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


def test_clobber_ref():
data = np.ones((2, 3, 5, 3200), dtype=np.float32)
data = np.ones((2, 3, 5, 3200))

output = np.array([1, 1, 2, 2, 3, 3, 4, 4], dtype=np.int16)
odd_even = np.array([1, 2, 1, 2, 1, 2, 1, 2], dtype=np.int16)
Expand All @@ -15,38 +15,24 @@ def test_clobber_ref():
2**5 + 2**7,
2**11 + 2**13,
0,
2**4],
0],
dtype=np.uint32)

clobber_ref(data, output, odd_even, mask)

compare = np.ones((2, 3, 5, 3200), dtype=np.float32)
compare[:, :, :, 648] = 0.
compare[:, :, :, 649] = 0.
compare[:, :, :, 668] = 0.
compare[:, :, :, 669] = 0.
compare[:, :, :, 690] = 0.
compare[:, :, :, 691] = 0.
compare[:, :, :, 710] = 0.
compare[:, :, :, 711] = 0.
compare[:, :, :, 1290] = 0.
compare[:, :, :, 1291] = 0.
compare[:, :, :, 1310] = 0.
compare[:, :, :, 1311] = 0.
compare[:, :, :, 1368] = 0.
compare[:, :, :, 1369] = 0.
compare[:, :, :, 1388] = 0.
compare[:, :, :, 1389] = 0.
compare[:, :, :, 2028] = 0.
compare[:, :, :, 2029] = 0.
compare[:, :, :, 2068] = 0.
compare[:, :, :, 2069] = 0.
compare[:, :, :, 2150] = 0.
compare[:, :, :, 2151] = 0.
compare[:, :, :, 2190] = 0.
compare[:, :, :, 2191] = 0.
compare[:, :, :, 3108] = 0.
compare[:, :, :, 3109] = 0.
compare = np.ones((2, 3, 5, 3200))
compare[..., 648: 648+4] = 0.
compare[..., 668: 668+4] = 0.
compare[..., 690: 690+4] = 0.
compare[..., 710: 710+4] = 0.
compare[..., 1890: 1890+4] = 0.
compare[..., 1910: 1910+4] = 0.
compare[..., 1808: 1808+4] = 0.
compare[..., 1828: 1828+4] = 0.
compare[..., 2028: 2028+4] = 0.
compare[..., 2068: 2068+4] = 0.
compare[..., 2150: 2150+4] = 0.
compare[..., 2190: 2190+4] = 0.

assert np.allclose(data, compare)

Expand All @@ -62,12 +48,12 @@ def test_decode_mask():
check = np.zeros(nrows, dtype=bool)
compare = [[5, 20],
[],
[18, 23, 27],
[28],
[4, 8, 13],
[3],
[6, 14],
[9, 25],
[28, 31],
[16, 24, 31]]
[0, 3],
[0, 7, 15]]
for row in range(nrows):
bits = decode_mask(output[row], mask[row])
check[row] = (bits == compare[row])
Expand Down
2 changes: 1 addition & 1 deletion jwst/refpix/tests/test_rm_intermittent_badpix.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_rm_intermittent_badpix():
data[..., 3128] = 15.

scipix_n, refpix_r = 16, 4
ovr_corr_mitigation_ftr = 1.8
ovr_corr_mitigation_ftr = 3.0
rm_intermittent_badpix(data, scipix_n, refpix_r, ovr_corr_mitigation_ftr)

compare = np.ones((2, 3, 5, 3200), dtype=np.float32)
Expand Down

0 comments on commit 9143aec

Please sign in to comment.