-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathorb_calculator.py
449 lines (350 loc) · 17.3 KB
/
orb_calculator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
import astropy.units as u
from astropy.modeling.models import Gaussian1D
from astropy.modeling import fitting
from lightkurve import LightCurve
import lmfit
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import zscore
import seaborn as sns
class OrbCalculator:
    """
    Builds the diagnostic figure used to decide whether the period at max power of a lightcurve's
    periodogram is a real orbital period

    On construction the eclipses are masked out of the flux, a sine wave is fitted to the result, the
    lightcurve and the sine fit are folded and binned, and a 2x2 figure is drawn. The figure is handed to
    preload_plots.save_plot() when preload_plots.preload is truthy; otherwise it is shown and the 'y'/'n'
    key press is stored in is_real_period.
    """

    def __init__(self, lightcurve_data, preload_plots):
        """
        Runs the whole period check for one lightcurve and draws the decision figure
        Name: __init__()
        Parameters:
            lightcurve_data: object describing the lightcurve; this class reads its time, flux, lightcurve,
                periodogram, period_at_max_power, lit_period, name and imag attributes
                (time and flux are assumed to be 1-D numeric arrays of equal length - TODO confirm)
            preload_plots: object exposing a preload flag and a save_plot(label, name) method
        Returns:
            None
        """
        self.lightcurve_data = lightcurve_data
        self.preload = preload_plots

        # Set to True by on_key() when the user presses 'y'
        self.is_real_period = False

        # Determine if the period is plausible
        self.is_plausible, self.cutoff = self.plausible_period()

        # Remove eclipses
        self.no_eclipse_flux = self.remove_eclipses()

        # Fit a sine wave to the eclipse-free flux
        self.sine_fit = self.fit_sine_wave(self.lightcurve_data.time, self.no_eclipse_flux)

        # Fold and bin lightcurve
        self.binned_lightcurve = self.fold_lightcurve()

        # Fold and bin sine fit
        self.binned_sine, self.sine_period = self.fold_sine_wave(
            self.lightcurve_data.time,
            self.sine_fit.params['frequency'].value,
            self.sine_fit.best_fit,
        )

        # The time extremes are needed several times below, so compute them once
        time_min = np.min(self.lightcurve_data.time)
        time_max = np.max(self.lightcurve_data.time)

        # Calculate time points of the sine wave (one marker per fitted sine period)
        self.time_points = np.arange(time_min, time_max, self.sine_period)

        # Plot window: skip the first day, then show three periods at max power
        self.xmin = time_min + 1 + self.lightcurve_data.period_at_max_power
        self.xmax = time_min + 1 + 4 * self.lightcurve_data.period_at_max_power

        # Create plots for determining if the period is real
        self.is_real_period_plot()

        # Either save or show plot depending on preload
        if preload_plots.preload:
            preload_plots.save_plot('Period', lightcurve_data.name)
        else:
            plt.show()

    def plausible_period(self):
        """
        Determines if the peak of the periodogram is plausible, i.e. whether its power rises above
        5 standard deviations of the periodogram power
        Name: plausible_period()
        Parameters:
            None
        Returns:
            boolean: True if the maximum power is above the cutoff, False otherwise
            cutoff: 5 sigma cutoff (5 * standard deviation of the NaN-free power)
        """
        # Remove NaNs from the periodogram
        power = self.lightcurve_data.periodogram.power
        periodogram_power = power[~np.isnan(power)]

        # The cutoff is drawn as a horizontal line on the power axis by plot_periodogram(), so the
        # plausibility check has to compare power against it (not a period difference, whose units
        # are unrelated to the power's standard deviation).
        cutoff = 5 * np.std(periodogram_power)

        # An all-NaN periodogram has no peak to test
        if periodogram_power.size == 0:
            return False, cutoff

        return bool(np.max(periodogram_power) > cutoff), cutoff

    def create_gaussian_model(self, time_start, time_end, num_gaussians=75):
        """
        Splits [time_start, time_end] into equal windows, seeds one Gaussian per window from the data inside
        it, and fits the sum of all Gaussians to the full lightcurve
        Name: create_gaussian_model()
        Parameters:
            time_start: start of the interval covered by the Gaussians
            time_end: end of the interval covered by the Gaussians
            num_gaussians: number of windows / Gaussians (default = 75)
        Returns:
            fitted_model: fitted compound model, or None if any window contains no data points
                          (or num_gaussians is less than 1)
        """
        if num_gaussians < 1:
            return None

        time = self.lightcurve_data.time
        flux = self.lightcurve_data.flux
        time_step = (time_end - time_start) / num_gaussians

        model_profiles = []
        try:
            for i in range(num_gaussians):
                window = (time > time_start + i * time_step) & (time < time_start + (i + 1) * time_step)
                # np.max raises ValueError on an empty window, which aborts the model below
                init_amp = np.max(flux[window])
                init_mean = np.mean(time[window])
                init_stddev = np.std(time[window])
                # Gaussian profile
                model_profiles.append(Gaussian1D(amplitude=init_amp, mean=init_mean, stddev=init_stddev))
        except ValueError:
            return None

        # Create compound model (sum of every window's Gaussian)
        compound_model = model_profiles[0]
        for profile in model_profiles[1:]:
            compound_model += profile

        # Fit the model to the data
        fitter = fitting.LevMarLSQFitter()
        return fitter(compound_model, self.lightcurve_data.time, self.lightcurve_data.flux)

    def find_sig_eclipses(self, fitted_model, z_threshold=2):
        """
        Picks out the Gaussians whose fitted amplitude is an outlier among all fitted amplitudes
        Name: find_sig_eclipses()
        Parameters:
            fitted_model: indexable/iterable collection of fitted Gaussians, each exposing
                          amplitude.value, mean.value and stddev.value
            z_threshold: minimum absolute z-score of the amplitude to count as significant (default = 2)
        Returns:
            significant_eclipses: list of (index, amplitude, mean, stddev) tuples
        """
        # Extract amplitudes from fitted Gaussians
        amplitudes = np.array([profile.amplitude.value for profile in fitted_model])

        # Calculate Z-scores
        z_scores = zscore(amplitudes)

        # Identify significant eclipses (plain int indices are kept on purpose for model indexing)
        return [
            (index, amplitudes[index], fitted_model[index].mean.value, fitted_model[index].stddev.value)
            for index, score in enumerate(z_scores)
            if np.abs(score) > z_threshold
        ]

    @staticmethod
    def _mask_recurring_eclipses(time, flux, eclipse_centers, half_width, period, fill_value):
        """
        Returns a copy of flux in which every point closer than half_width to an eclipse center, or to a
        center shifted forward by a whole number of periods, is replaced by fill_value
        """
        cleaned_flux = np.copy(flux)

        centers = np.asarray(eclipse_centers, dtype=float)
        centers = centers[np.isfinite(centers)]
        if centers.size == 0:
            return cleaned_flux

        time_max = np.max(time)
        # Without a positive, finite period the centers cannot be advanced, so mask them only once
        can_repeat = bool(np.isfinite(period) and period > 0)

        while True:
            for center in centers:
                in_eclipse = (time > center - half_width) & (time < center + half_width)
                cleaned_flux[in_eclipse] = fill_value
            if not can_repeat:
                break
            centers = centers + period
            # Stop only once EVERY center is past the data, so no eclipse of the last cycle is skipped
            if np.min(centers) >= time_max:
                break

        return cleaned_flux

    def remove_eclipses(self):
        """
        Finds the eclipses within the first period at max power and replaces them, and their repeats one
        period apart, with the average flux
        Name: remove_eclipses()
        Parameters:
            None
        Returns:
            no_eclipse_flux: flux with the eclipses replaced by the mean flux; the original flux is returned
                             when no Gaussian model could be built, and an unmodified copy when no
                             significant eclipse is found
        """
        time = self.lightcurve_data.time
        flux = self.lightcurve_data.flux
        period = self.lightcurve_data.period_at_max_power

        # Time start and end for finding eclipses (the first period of the lightcurve)
        time_start = np.min(time)
        time_end = time_start + period

        # Create the gaussian model
        fitted_model = self.create_gaussian_model(time_start, time_end)
        if fitted_model is None:
            return flux

        # Find significant eclipses
        significant_eclipses = self.find_sig_eclipses(fitted_model)
        if not significant_eclipses:
            # Nothing to mask; indexing the empty list of means would raise IndexError
            return np.copy(flux)

        # Isolate eclipse means and the average standard deviation
        eclipse_means = [mean for _, _, mean, _ in significant_eclipses]
        eclipse_stddev = np.mean([stddev for _, _, _, stddev in significant_eclipses])

        # Replace every eclipse in the lightcurve (+/- 2 sigma around its center) with the average flux
        avg_flux = np.mean(flux)
        return self._mask_recurring_eclipses(time, flux, eclipse_means, 2 * eclipse_stddev, period, avg_flux)

    def sine_wave(self, x, amplitude, frequency, phase):
        """
        Creates a sine wave based off of given parameters
        Name: sine_wave()
        Parameters:
            x: data points
            amplitude: desired amplitude
            frequency: desired frequency
            phase: desired phase
        Returns:
            a sine wave
        """
        return amplitude * np.sin((2 * np.pi * frequency * x) + phase)

    def find_bin_value(self, lightcurve, num_bins):
        """
        Calculates the bin width that splits the duration of the lightcurve into the desired number of bins
        Name: find_bin_value()
        Parameters:
            lightcurve: lightcurve whose time.value array gives the time stamps (treated as days)
            num_bins: desired number of bins
        Returns:
            bin_value: number of minutes for each bin
        """
        times = lightcurve.time.value
        total_duration_mins = ((times[-1] - times[0]) * u.day).to(u.minute)
        return (total_duration_mins / num_bins).value

    def fit_sine_wave(self, time, flux):
        """
        Fits a sine wave to a lightcurve using the lmfit package
        Name: fit_sine_wave()
        Parameters:
            time: time data for the lightcurve
            flux: flux data for the lightcurve
        Returns:
            result: fitted sine wave
        """
        # Make an lmfit object, seeded from the periodogram peak, and fit it
        model = lmfit.Model(self.sine_wave)
        params = model.make_params(amplitude=self.lightcurve_data.periodogram.max_power,
                                   frequency=1 / self.lightcurve_data.period_at_max_power,
                                   phase=0.0)
        return model.fit(flux, params, x=time)

    def fold_lightcurve(self, num_folds=1):
        """
        Folds the lightcurve on the period at max power, and bins it into bins 1/100th of the folded duration
        wide (per fold)
        Name: fold_lightcurve()
        Parameters:
            num_folds: number of folds wanted to do on the period (default = 1, just folding on the period at
                       max power)
        Returns:
            binned_lightcurve: folded and binned lightcurve
        """
        # Fold lightcurve
        fold_period = num_folds * self.lightcurve_data.period_at_max_power
        folded_lightcurve = self.lightcurve_data.lightcurve.fold(period=fold_period)

        # Calculate bin value
        bin_value = self.find_bin_value(folded_lightcurve, num_folds * 100)

        # Bin the folded lightcurve
        return folded_lightcurve.bin(bin_value * u.min)

    def fold_sine_wave(self, x, frequency, sine_wave, num_folds=1):
        """
        Folds the fitted sine wave on the period at max power (the same period used for the lightcurve) and
        bins it into bins 1/50th of the folded duration wide (per fold)
        Name: fold_sine_wave()
        Parameters:
            x: time data of the lightcurve
            frequency: frequency of the fitted sine wave
            sine_wave: fitted sine wave (best_fit)
            num_folds: number of folds wanted to do on the period (default = 1)
        Returns:
            binned_sine: folded and binned sine wave
            sine_period: period of the fitted sine wave (1 / frequency)
        """
        # Period of the fitted sine wave, used for the period lines
        sine_period = 1 / frequency

        # Make the sine wave into a lightcurve
        sine_lightcurve = LightCurve(time=x, flux=sine_wave)

        # Fold sine wave
        folded_sine = sine_lightcurve.fold(period=num_folds * self.lightcurve_data.period_at_max_power)

        # Calculate bin value and bin the folded sine wave
        bin_value = self.find_bin_value(folded_sine, num_folds * 50)
        binned_sine = folded_sine.bin(bin_value * u.min)

        return binned_sine, sine_period

    def plot_periodogram(self, axis):
        """
        Plots the lightcurve's periodogram on a given axis, as well as the period at max power, the literature
        period, if any, and the 5-sigma cutoff
        Name: plot_periodogram()
        Parameters:
            axis: axis to be plotted on
        Returns:
            None
        """
        # Plot title
        axis.set_title('Periodogram', fontsize=12)
        axis.set_xlabel(r'$P_{\text{orb}}$ (days)', fontsize=10)
        axis.set_ylabel('Power', fontsize=10)

        axis.plot(self.lightcurve_data.periodogram.period, self.lightcurve_data.periodogram.power, color='#9AADD0')
        axis.axvline(x=self.lightcurve_data.period_at_max_power, color="#101935", ls=(0, (4, 5)), lw=2,
                     label=fr'$P_{{\text{{orb, max power}}}}={np.round(self.lightcurve_data.period_at_max_power, 3)}$ days')

        # Plot literature period if there is one (0.0 means no literature period)
        if self.lightcurve_data.lit_period != 0.0:
            axis.axvline(x=self.lightcurve_data.lit_period, color='#A30015',
                         label=fr'Literature $P_{{\text{{orb}}}}={np.round(self.lightcurve_data.lit_period, 3)}$ days')

        # Plot 5 sigma cutoff
        axis.axhline(y=self.cutoff, color='#4A5D96', ls=(0, (4, 5)), lw=2, label='5-sigma cutoff')

        # Change scale to be log
        axis.set_xscale('log')

        # Add legend
        axis.legend(loc='upper left')

    def plot_binned_lightcurve(self, axis):
        """
        Plots the binned lightcurve and the binned sine wave on a given axis
        Name: plot_binned_lightcurve()
        Parameters:
            axis: axis to be plotted on
        Returns:
            None
        """
        # Plot title
        axis.set_title(r'Lightcurve Folded on $P_{\text{orb, max power}}$', fontsize=12)
        axis.set_xlabel('Phase', fontsize=10)
        axis.set_ylabel('Normalized Flux', fontsize=10)

        # Plot the binned lightcurve as error bars (flux +/- flux_err)
        axis.vlines(self.binned_lightcurve.phase.value,
                    self.binned_lightcurve.flux - self.binned_lightcurve.flux_err,
                    self.binned_lightcurve.flux + self.binned_lightcurve.flux_err, color='#9AADD0', lw=2)

        # Plot the binned sine fit
        axis.plot(self.binned_sine.phase.value, self.binned_sine.flux.value, color='#101935', label='Folded Sine Wave')

        # Add legend
        axis.legend(loc='upper right')

    def plot_lightcurve_and_sine(self, axis):
        """
        Plots the lightcurve and the sine wave, as well as the period of the sine wave
        Name: plot_lightcurve_and_sine()
        Parameters:
            axis: axis to be plotted on
        Returns:
            None
        """
        # Plot title
        axis.set_title('Lightcurve', fontsize=12)
        axis.set_xlabel('Time (days)', fontsize=10)
        axis.set_ylabel('Normalized Flux', fontsize=10)

        # Plot lightcurve as error bars (flux +/- flux_err)
        lightcurve = self.lightcurve_data.lightcurve
        axis.vlines(lightcurve.time.value,
                    lightcurve.flux - lightcurve.flux_err,
                    lightcurve.flux + lightcurve.flux_err, color='#9AADD0')

        # Add vertical lines at each period interval of the sine wave; only the first one is labelled so
        # the legend shows a single entry
        for index, tp in enumerate(self.time_points):
            axis.axvline(x=tp, color='#4A5D96', ls=(0, (4, 5)), lw=2,
                         label=fr'$P_{{\text{{orb, sine}}}} = {np.round(self.sine_period, 3)}$ days' if index == 0 else "")

        # Plot sine wave
        axis.plot(self.lightcurve_data.time, self.sine_fit.best_fit, color='#101935', label='Fitted Sine Wave')

        # Set xlim and plot legend
        axis.set_xlim(self.xmin, self.xmax)
        axis.legend(loc='upper right')

    def plot_residuals(self, axis):
        """
        Plots the residuals of the lightcurve, which is the flux subtracted by the sine fit
        Name: plot_residuals()
        Parameters:
            axis: axis to be plotted on
        Returns:
            None
        """
        # Calculate residuals (lightcurve flux - sine wave flux)
        residuals = self.lightcurve_data.flux - self.sine_fit.best_fit

        # Plot title
        axis.set_title('Flux - Fitted Sine Wave', fontsize=12)
        axis.set_xlabel('Time (days)', fontsize=10)
        axis.set_ylabel('Normalized Flux', fontsize=10)

        # Plot the residuals
        axis.plot(self.lightcurve_data.time, residuals, color='#9AADD0')

        # Set xlim (no legend needed)
        axis.set_xlim(self.xmin, self.xmax)

    def is_real_period_plot(self):
        """
        Present a plot of the periodogram, binned lightcurve, lightcurve, and residuals, which are then used to
        determine if the period at max power is real or not
        Name: is_real_period_plot()
        Parameters:
            None
        Returns:
            None
        """
        # Plot basics
        sns.set_style("whitegrid")
        fig, axs = plt.subplots(2, 2, figsize=(14, 8))
        fig.subplots_adjust(hspace=0.35)
        fig.suptitle("Press 'y' if the period is real, 'n' if not.", fontweight='bold')

        # Hint derived from the 5-sigma plausibility check
        if self.is_plausible:
            note = r'Note: $P_{\text{orb, max power}}$ is over 5 sigma, so MIGHT be real'
        else:
            note = r'Note: $P_{\text{orb, max power}}$ is under 5 sigma, so might NOT be real'
        fig.text(0.5, 0.928, note, ha='center', fontsize=12, style='italic')

        fig.text(0.5, 0.05, f'{self.lightcurve_data.name}', ha='center', fontsize=16, fontweight='bold')
        fig.text(0.5, 0.02, fr'$i_{{\text{{mag}}}}={self.lightcurve_data.imag}$', ha='center', fontsize=12, fontweight='bold')

        # A lambda is kept on purpose: matplotlib stores only a weak reference to bound methods passed to
        # mpl_connect, whereas the lambda keeps this object reachable while the figure holds the callback.
        fig.canvas.mpl_connect('key_press_event', lambda event: self.on_key(event))

        # Plot the periodogram
        self.plot_periodogram(axs[0, 0])

        # Plot the binned lightcurve
        self.plot_binned_lightcurve(axs[1, 0])

        # Plot the lightcurve with the sine fit
        self.plot_lightcurve_and_sine(axs[0, 1])

        # Plot residuals
        self.plot_residuals(axs[1, 1])

    def on_key(self, event):
        """
        Event function that records the user's verdict: 'y' marks the period as real, 'n' leaves it marked
        as not real, and either one closes the current figure; any other key only prints a reminder
        Name: on_key()
        Parameters:
            event: key press event
        Returns:
            None
        """
        if event.key not in ('y', 'n'):
            print("Invalid key input, select 'y' or 'n'")
            return

        if event.key == 'n':
            print('Period is not real, loading next plot ... \n')
        else:
            self.is_real_period = True

        plt.close()