Skip to content

Commit

Permalink
Merge pull request #141 from HealthyPear/fix-energy_scale_in_and_from…
Browse files Browse the repository at this point in the history
…_model

Ensure that estimated energy is always recorded in linear scale
  • Loading branch information
HealthyPear authored May 19, 2021
2 parents fd7b693 + a819c7b commit b1e1574
Show file tree
Hide file tree
Showing 7 changed files with 24 additions and 2 deletions.
1 change: 1 addition & 0 deletions protopipe/aux/example_config_files/AdaBoostRegressor.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ GridSearchCV:
Method:
name: 'sklearn.ensemble.AdaBoostRegressor'
target_name: 'true_energy'
log_10_target: True # this makes the model use log10(target_name)
# Please, see scikit-learn's API for what each parameter means
# NOTE: null == None
base_estimator:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ GridSearchCV:
# Definition of the model algorithm/method used and its hyper-parameters
Method:
name: 'sklearn.ensemble.RandomForestRegressor' # DO NOT CHANGE
target_name: 'log10_true_energy'
target_name: 'true_energy'
log_10_target: True # this makes the model use log10(target_name)
tuned_parameters:
# Please, see scikit-learn's API for what each parameter means
# NOTE: null == None
Expand Down
8 changes: 8 additions & 0 deletions protopipe/scripts/build_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,14 @@ def main():

if class_name in model_types["regressor"]:

try:
log_10_target = cfg["Method"]["log_10_target"]
except KeyError:
log_10_target = True

if log_10_target:
target_name = f"log10_{target_name}"

# Get the selection cuts
cuts = make_cut_list(cfg["SigFiducialCuts"])

Expand Down
5 changes: 5 additions & 0 deletions protopipe/scripts/data_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def main():

# Read configuration file
regressor_config = load_config(args.regressor_config)
log_10_target = regressor_config["Method"]["log_10_target"]

regressor_files = (
args.regressor_dir + "/regressor_{cam_id}_{regressor}.pkl.gz"
Expand Down Expand Up @@ -380,6 +381,10 @@ def main():
energy_tel[idx] = model.predict(features_values)
weight_tel[idx] = data["estimation_weight"]

if log_10_target:
energy_tel[idx] = 10**energy_tel[idx]
weight_tel[idx] = 10**weight_tel[idx]

reco_energy_tel[tel_id] = energy_tel[idx]

reco_energy = np.sum(weight_tel * energy_tel) / sum(weight_tel)
Expand Down
1 change: 1 addition & 0 deletions protopipe/scripts/tests/test_AdaBoostRegressor.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ GridSearchCV:
Method:
name: 'sklearn.ensemble.AdaBoostRegressor'
target_name: 'true_energy'
log_10_target: True # this makes the model use log10(target_name)
# Please, see scikit-learn's API for what each parameter means
# NOTE: null == None
base_estimator:
Expand Down
3 changes: 2 additions & 1 deletion protopipe/scripts/tests/test_RandomForestRegressor.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ GridSearchCV:
# Definition of the model algorithm/method used and its hyper-parameters
Method:
name: 'sklearn.ensemble.RandomForestRegressor' # DO NOT CHANGE
target_name: 'log10_true_energy'
target_name: 'true_energy'
log_10_target: True # this makes the model use log10(target_name)
tuned_parameters:
# Please, see scikit-learn's API for what each parameter means
# NOTE: null == None
Expand Down
5 changes: 5 additions & 0 deletions protopipe/scripts/write_dl2.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ def main():

# Read configuration file
regressor_config = load_config(args.regressor_config)
log_10_target = regressor_config["Method"]["log_10_target"]

regressor_files = (
args.regressor_dir + "/regressor_{cam_id}_{regressor}.pkl.gz"
Expand Down Expand Up @@ -427,6 +428,10 @@ class RecoEvent(tb.IsDescription):
else:
energy_tel[idx] = np.nan

if log_10_target:
energy_tel[idx] = 10**energy_tel[idx]
weight_tel[idx] = 10**weight_tel[idx]

# Record the values regardless of the validity
# We don't use this now, but it should be recorded
energy_tel_classifier[tel_id] = energy_tel[idx]
Expand Down

0 comments on commit b1e1574

Please sign in to comment.