Skip to content

Commit

Permalink
solve estimation error
Browse files Browse the repository at this point in the history
  • Loading branch information
Guillaume Levy committed Jan 25, 2023
1 parent 346693f commit a534f48
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 19 deletions.
4 changes: 2 additions & 2 deletions core/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,8 +314,8 @@ def compute_bins(pred: np.ndarray, energy: np.ndarray) -> Tuple[np.ndarray, np.n
pred_mean, pred_std, true_mean = [], [], []
for e_bin in enumerate(bins):
indicies = np.where(ind_bins == e_bin[0])[0]
pred_mean.append(np.mean(pred[:, indicies] - energy[indicies]))
pred_std.append(np.sqrt(np.mean(np.std(pred[:, indicies], axis=0)**2)))
pred_mean.append(np.mean((pred[:, indicies] - energy[indicies])/energy[indicies]))
pred_std.append(np.sqrt(np.mean(np.std((pred[:, indicies] - energy[indicies])/energy[indicies], axis=0)**2)))
true_mean.append(np.mean(energy[indicies]))

pred_mean = np.array(pred_mean)
Expand Down
17 changes: 0 additions & 17 deletions core/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,29 +83,12 @@ def plot_bins_results(train_values: Tuple[np.ndarray, np.ndarray, np.ndarray],

#Plot the results on the bins
plt.clf()
#TODO: Error instead of dividing of taking the mean of (E_pr - e_th)/e_th we did mean(E_pr - E_th)/mean(E_th)
plt.errorbar(true_train_mean, pred_train_mean, yerr=pred_train_std, fmt="o", label="Train")
plt.plot(true_train_mean, [0 for _ in range(len(true_train_mean))], "k")
plt.errorbar(true_test_mean, pred_test_mean, yerr=pred_test_std, fmt="o", label="Val")
plt.plot(true_test_mean, [0 for _ in range(len(true_test_mean))], "k")
plt.title("Results")
plt.xlabel("ground truth energy (EeV)")
plt.ylabel("$E_{pr} - E_{th} (EeV)$")
plt.xlim(0, 4.1)
plt.legend()
if fig_dir is not None:
plt.savefig(fig_dir + "/" + "all")

#Plot the residue
plt.figure()
plt.errorbar(true_train_mean, pred_train_mean/true_train_mean,
yerr=pred_train_std/true_train_mean, fmt="o", label="Train")
plt.plot(true_train_mean, [0 for _ in range(len(true_train_mean))], "k")
plt.errorbar(true_test_mean, pred_test_mean/true_test_mean,
yerr=pred_test_std/true_test_mean, fmt="o", label="Val")
plt.plot(true_test_mean, [0 for _ in range(len(true_test_mean))], "k")
plt.title("Results")
plt.xlabel("ground truth energy (EeV)")
plt.ylabel(r"$\frac{E_{pr} - E_{th}}{E_{th}} $")
plt.xlim(0, 4.1)
plt.legend()
Expand Down

0 comments on commit a534f48

Please sign in to comment.