-
Notifications
You must be signed in to change notification settings - Fork 177
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix bug in CKF and add script for KF timing test #126
Changes from all commits
cb36dff
d7cddf6
84a8f40
dbb0c9e
70e9d10
3038d5f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -189,6 +189,12 @@ class KalmanFitter { | |
public: | ||
/// Shorthand definition | ||
using MeasurementSurfaces = std::multimap<const Layer*, const Surface*>; | ||
/// The navigator type | ||
using KalmanNavigator = typename propagator_t::Navigator; | ||
|
||
/// The navigator has DirectNavigator type or not | ||
static constexpr bool isDirectNavigator = | ||
std::is_same<KalmanNavigator, DirectNavigator>::value; | ||
|
||
/// Default constructor is deleted | ||
KalmanFitter() = delete; | ||
|
@@ -220,13 +226,6 @@ class KalmanFitter { | |
/// Owned logging instance | ||
std::shared_ptr<const Logger> m_logger; | ||
|
||
/// The navigator type | ||
using KalmanNavigator = typename decltype(m_propagator)::Navigator; | ||
|
||
/// The navigator has DirectNavigator type or not | ||
static constexpr bool isDirectNavigator = | ||
std::is_same<KalmanNavigator, DirectNavigator>::value; | ||
|
||
/// @brief Propagator Actor plugin for the KalmanFilter | ||
/// | ||
/// @tparam source_link_t is an type fulfilling the @c SourceLinkConcept | ||
|
@@ -297,8 +296,9 @@ class KalmanFitter { | |
// -> Get the measurement / calibrate | ||
// -> Create the predicted state | ||
// -> Perform the kalman update | ||
// -> Check outlier behavior (@todo) | ||
// -> Fill strack state information & update stepper information | ||
// -> Check outlier behavior | ||
// -> Fill strack state information & update stepper information if | ||
// non-outlier | ||
if (state.stepping.navDir == forward and not result.smoothed and | ||
not result.forwardFiltered) { | ||
ACTS_VERBOSE("Perform forward filter step"); | ||
|
@@ -328,15 +328,15 @@ class KalmanFitter { | |
if (backwardFiltering and not result.forwardFiltered) { | ||
ACTS_VERBOSE("Forward filtering done"); | ||
result.forwardFiltered = true; | ||
// Run backward filtering | ||
// Start to run backward filtering: | ||
// Reverse navigation direction and reset navigation and stepping | ||
// state to last measurement | ||
ACTS_VERBOSE("Reverse navigation direction."); | ||
reverse(state, stepper, result); | ||
} else if (not result.smoothed) { | ||
// -> Sort the track states (as now the path length is set) | ||
// -> Call the smoothing | ||
// -> Set a stop condition when all track states have been handled | ||
// --> Search the starting state to run the smoothing | ||
// --> Call the smoothing | ||
// --> Set a stop condition when all track states have been handled | ||
ACTS_VERBOSE("Finalize/run smoothing"); | ||
auto res = finalize(state, stepper, result); | ||
if (!res.ok()) { | ||
|
@@ -802,34 +802,30 @@ class KalmanFitter { | |
// Remember you smoothed the track states | ||
result.smoothed = true; | ||
|
||
// Get the index of measurement states; | ||
// Get the indices of measurement states; | ||
std::vector<size_t> measurementIndices; | ||
auto lastState = result.fittedStates.getTrackState(result.trackTip); | ||
if (lastState.typeFlags().test(Acts::TrackStateFlag::MeasurementFlag)) { | ||
measurementIndices.push_back(result.trackTip); | ||
} | ||
measurementIndices.reserve(result.measurementStates); | ||
// Count track states to be smoothed | ||
size_t nStates = 0; | ||
result.fittedStates.applyBackwards(result.trackTip, [&](auto st) { | ||
// Smoothing will start from the last measurement state | ||
if (measurementIndices.empty()) { | ||
// No smoothed parameter for the last few non-measurment states | ||
bool isMeasurement = | ||
st.typeFlags().test(TrackStateFlag::MeasurementFlag); | ||
if (isMeasurement) { | ||
measurementIndices.emplace_back(st.index()); | ||
} else if (measurementIndices.empty()) { | ||
// No smoothed parameters if the last measurement state has not been | ||
// found yet | ||
st.data().ismoothed = detail_lt::IndexData::kInvalid; | ||
} else { | ||
nStates++; | ||
} | ||
size_t iprevious = st.previous(); | ||
if (iprevious != Acts::detail_lt::IndexData::kInvalid) { | ||
auto previousState = result.fittedStates.getTrackState(iprevious); | ||
if (previousState.typeFlags().test( | ||
Acts::TrackStateFlag::MeasurementFlag)) { | ||
measurementIndices.push_back(iprevious); | ||
} | ||
// Start count when the last measurement state is found | ||
if (not measurementIndices.empty()) { | ||
nStates++; | ||
} | ||
}); | ||
// Return error if the track has no measurement states (but this should | ||
// not happen) | ||
if (measurementIndices.empty()) { | ||
ACTS_ERROR("Smoothing for a track without measurements."); | ||
return KalmanFitterError::SmoothFailed; | ||
} | ||
// Screen output for debugging | ||
|
@@ -982,6 +978,7 @@ class KalmanFitter { | |
auto result = m_propagator.template propagate(sParameters, kalmanOptions); | ||
|
||
if (!result.ok()) { | ||
ACTS_ERROR("Propapation failed: " << result.error()); | ||
return result.error(); | ||
} | ||
|
||
|
@@ -997,6 +994,7 @@ class KalmanFitter { | |
} | ||
|
||
if (!kalmanResult.result.ok()) { | ||
ACTS_ERROR("KalmanFilter failed: " << kalmanResult.result.error()); | ||
return kalmanResult.result.error(); | ||
} | ||
|
||
|
@@ -1085,6 +1083,7 @@ class KalmanFitter { | |
auto result = m_propagator.template propagate(sParameters, kalmanOptions); | ||
|
||
if (!result.ok()) { | ||
ACTS_ERROR("Propapation failed: " << result.error()); | ||
return result.error(); | ||
} | ||
|
||
|
@@ -1100,6 +1099,7 @@ class KalmanFitter { | |
} | ||
|
||
if (!kalmanResult.result.ok()) { | ||
ACTS_ERROR("KalmanFilter failed: " << kalmanResult.result.error()); | ||
return kalmanResult.result.error(); | ||
} | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I like the meaningful error output messages. |
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -72,7 +72,7 @@ FW::EventGenerator::Config FW::Options::readPythia8Options( | |
Pythia8Generator::makeFunction(hard, lvl)}, | ||
{PoissonMultiplicityGenerator{mu}, | ||
GaussianVertexGenerator{{vtxStdXY, vtxStdXY, vtxStdZ, vtxStdT}}, | ||
Pythia8Generator::makeFunction(hard, lvl)}, | ||
Pythia8Generator::makeFunction(pileup, lvl)}, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Well spotted! |
||
}; | ||
cfg.shuffle = vm["evg-shuffle"].as<bool>(); | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
import ROOT | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we will have to do a cleanup of the scripts, probably as preparation for the Tutorial/Workshop. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's a good idea! |
||
import csv | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
|
||
# Data preparation | ||
ptDict = {} | ||
|
||
# Open the output file | ||
with open('output.log', mode='r') as csv_file: | ||
csv_reader = csv.reader(csv_file, delimiter=',') | ||
# read lines and go for it | ||
for csv_row in csv_reader : | ||
if len(csv_row) > 1 : | ||
# get the job id | ||
jobID = csv_row[0] | ||
# etabin, pt value, exec time | ||
etabin = float(csv_row[1]) | ||
ptvalue = float(csv_row[2]) | ||
exectime = float(csv_row[3]) | ||
|
||
# Make sure you have all the keys ready | ||
try : | ||
pdict = ptDict[ptvalue] | ||
except : | ||
ptDict[ptvalue] = {} | ||
pdict = ptDict[ptvalue] | ||
|
||
# Now fill the sub dictionary | ||
try : | ||
vpdict = pdict[etabin] | ||
except: | ||
pdict[etabin] = [] | ||
vpdict = pdict[etabin] | ||
|
||
vpdict += [ exectime ] | ||
|
||
# plot the ptDict | ||
plt.figure(figsize=(7, 5)) | ||
|
||
ax = plt.subplot(111) | ||
plt.loglog(ptDict.keys(),[i[0][0] for i in np.array(list(ptDict.values()))],'.-', label='0<$\eta$<0.5') | ||
plt.loglog(ptDict.keys(),[i[1][0] for i in np.array(list(ptDict.values()))],'.-', label='0.5<$\eta$<1.0') | ||
plt.loglog(ptDict.keys(),[i[2][0] for i in np.array(list(ptDict.values()))],'.-', label='1.0<$\eta$<1.5') | ||
plt.loglog(ptDict.keys(),[i[3][0] for i in np.array(list(ptDict.values()))],'.-', label='1.5<$\eta$<2.0') | ||
plt.loglog(ptDict.keys(),[i[4][0] for i in np.array(list(ptDict.values()))],'.-', label='2.0<$\eta$<2.5') | ||
ax.set_xlabel('$p_T$ [GeV/c]') | ||
ax.set_ylabel('time/track [sec]') | ||
plt.yscale('log') | ||
ax.set_xlim((0.09,105)) | ||
plt.legend(numpoints=1) | ||
|
||
plt.suptitle("KF timing vs. $p_T$") | ||
|
||
plt.show() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just shifting lines, hence ok.