Skip to content

Commit

Permalink
feat: Fitters calculate track quantities
Browse files Browse the repository at this point in the history
Adds a central helper function that calculates number of holes,
outliers, measurements and shared hits based on the type flags, and
stores them in the track.
  • Loading branch information
paulgessinger committed Apr 4, 2023
1 parent 2235f82 commit 83aa3e9
Show file tree
Hide file tree
Showing 8 changed files with 115 additions and 3 deletions.
46 changes: 46 additions & 0 deletions Core/include/Acts/EventData/TrackHelpers.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// This file is part of the Acts project.
//
// Copyright (C) 2023 CERN for the benefit of the Acts project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
#pragma once

#include "Acts/EventData/MultiTrajectory.hpp"
#include "Acts/EventData/TrackContainer.hpp"

namespace Acts {
template <typename track_container_t, typename track_state_container_t,
template <typename> class holder_t>
void calculateTrackQuantities(
Acts::TrackProxy<track_container_t, track_state_container_t, holder_t,
false>
track) {
track.chi2() = 0;
track.nDoF() = 0;

track.nHoles() = 0;
track.nMeasurements() = 0;
track.nSharedHits() = 0;
track.nOutliers() = 0;

for (const auto& trackState : track.trackStates()) {
auto typeFlags = trackState.typeFlags();

if (typeFlags.test(Acts::TrackStateFlag::MeasurementFlag)) {
if (typeFlags.test(Acts::TrackStateFlag::SharedHitFlag)) {
track.nSharedHits()++;
}

track.nMeasurements()++;
track.chi2() += trackState.chi2();
track.nDoF() += trackState.calibratedSize();
} else if (typeFlags.test(Acts::TrackStateFlag::OutlierFlag)) {
track.nOutliers()++;
} else if (typeFlags.test(Acts::TrackStateFlag::HoleFlag)) {
track.nHoles()++;
}
}
}
} // namespace Acts
32 changes: 30 additions & 2 deletions Core/include/Acts/EventData/TrackProxy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,34 @@ class TrackProxy {
return component<unsigned int>(hashString("nHoles"));
}

/// Return a mutable reference to the number of outliers for the track.
/// Mutable version
/// @return The number of outliers
template <bool RO = ReadOnly, typename = std::enable_if_t<!RO>>
unsigned int& nOutliers() {
return component<unsigned int>(hashString("nOutliers"));
}

/// Return the number of outliers for the track. Const version
/// @return The number of outliers
unsigned int nOutliers() const {
return component<unsigned int>(hashString("nOutliers"));
}

/// Return a mutable reference to the number of shared hits for the track.
/// Mutable version
/// @return The number of shared hits
template <bool RO = ReadOnly, typename = std::enable_if_t<!RO>>
unsigned int& nSharedHits() {
return component<unsigned int>(hashString("nSharedHits"));
}

/// Return the number of shared hits for the track. Const version
/// @return The number of shared hits
unsigned int nSharedHits() const {
return component<unsigned int>(hashString("nSharedHits"));
}

/// Return a mutable reference to the chi squared
/// Mutable version
/// @return The chi squared
Expand All @@ -446,13 +474,13 @@ class TrackProxy {
/// @return The the number of degrees of freedom
template <bool RO = ReadOnly, typename = std::enable_if_t<!RO>>
unsigned int& nDoF() {
return component<unsigned int>(hashString("nHoles"));
return component<unsigned int>(hashString("ndf"));
}

/// Return the number of degrees of freedom for the track. Const version
/// @return The number of degrees of freedom
unsigned int nDoF() const {
return component<unsigned int>(hashString("nHoles"));
return component<unsigned int>(hashString("ndf"));
}

/// Return the index of this track in the track container
Expand Down
10 changes: 10 additions & 0 deletions Core/include/Acts/EventData/VectorTrackContainer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ class VectorTrackContainerBase {
return &instance.m_chi2[itrack];
case "ndf"_hash:
return &instance.m_ndf[itrack];
case "nOutliers"_hash:
return &instance.m_nOutliers[itrack];
case "nSharedHits"_hash:
return &instance.m_nSharedHits[itrack];
default:
auto it = instance.m_dynamic.find(key);
if (it == instance.m_dynamic.end()) {
Expand Down Expand Up @@ -105,6 +109,10 @@ class VectorTrackContainerBase {
result = result && m_chi2.size() == size;
assert(result);
result = result && m_ndf.size() == size;
assert(result);
result = result && m_nOutliers.size() == size;
assert(result);
result = result && m_nSharedHits.size() == size;

for (const auto& [key, col] : m_dynamic) {
(void)key;
Expand Down Expand Up @@ -138,6 +146,8 @@ class VectorTrackContainerBase {
std::vector<unsigned int> m_nHoles;
std::vector<float> m_chi2;
std::vector<unsigned int> m_ndf;
std::vector<unsigned int> m_nOutliers;
std::vector<unsigned int> m_nSharedHits;

std::unordered_map<HashedString, std::unique_ptr<detail::DynamicColumnBase>>
m_dynamic;
Expand Down
4 changes: 4 additions & 0 deletions Core/include/Acts/TrackFinding/CombinatorialKalmanFilter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "Acts/EventData/MultiTrajectory.hpp"
#include "Acts/EventData/MultiTrajectoryHelpers.hpp"
#include "Acts/EventData/TrackContainer.hpp"
#include "Acts/EventData/TrackHelpers.hpp"
#include "Acts/EventData/TrackParameters.hpp"
#include "Acts/EventData/TrackStatePropMask.hpp"
#include "Acts/Geometry/GeometryContext.hpp"
Expand Down Expand Up @@ -1379,6 +1380,9 @@ class CombinatorialKalmanFilter {
track.parameters() = parameters.parameters();
track.covariance() = *parameters.covariance();
track.setReferenceSurface(parameters.referenceSurface().getSharedPtr());

calculateTrackQuantities(track);

tracks.push_back(track);
}

Expand Down
3 changes: 3 additions & 0 deletions Core/include/Acts/TrackFitting/Chi2Fitter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "Acts/EventData/MultiTrajectory.hpp"
#include "Acts/EventData/MultiTrajectoryHelpers.hpp"
#include "Acts/EventData/SourceLink.hpp"
#include "Acts/EventData/TrackHelpers.hpp"
#include "Acts/EventData/TrackParameters.hpp"
#include "Acts/EventData/VectorMultiTrajectory.hpp"
#include "Acts/Geometry/GeometryContext.hpp"
Expand Down Expand Up @@ -791,6 +792,8 @@ class Chi2Fitter {
track.nMeasurements() = c2r.measurementStates;
track.nHoles() = c2r.measurementHoles;

calculateTrackQuantities(track);

if (trackContainer.hasColumn(hashString("chi2"))) {
track.template component<ActsScalar, hashString("chi2")>() =
c2r.chisquare;
Expand Down
3 changes: 3 additions & 0 deletions Core/include/Acts/TrackFitting/GaussianSumFitter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#pragma once

#include "Acts/EventData/TrackHelpers.hpp"
#include "Acts/EventData/VectorMultiTrajectory.hpp"
#include "Acts/Propagator/EigenStepper.hpp"
#include "Acts/Propagator/MultiStepperAborters.hpp"
Expand Down Expand Up @@ -442,6 +443,8 @@ struct GaussianSumFitter {
track.setReferenceSurface(params.referenceSurface().getSharedPtr());
}

calculateTrackQuantities(track);

track.nMeasurements() = measurementStatesFinal;
track.nHoles() = fwdGsfResult.measurementHoles;

Expand Down
7 changes: 7 additions & 0 deletions Core/include/Acts/TrackFitting/KalmanFitter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "Acts/EventData/MultiTrajectory.hpp"
#include "Acts/EventData/MultiTrajectoryHelpers.hpp"
#include "Acts/EventData/SourceLink.hpp"
#include "Acts/EventData/TrackHelpers.hpp"
#include "Acts/EventData/TrackParameters.hpp"
#include "Acts/EventData/VectorMultiTrajectory.hpp"
#include "Acts/Geometry/GeometryContext.hpp"
Expand Down Expand Up @@ -1126,6 +1127,9 @@ class KalmanFitter {
}
track.nMeasurements() = kalmanResult.measurementStates;
track.nHoles() = kalmanResult.measurementHoles;

calculateTrackQuantities(track);

if (trackContainer.hasColumn(hashString("smoothed"))) {
track.template component<bool, hashString("smoothed")>() =
kalmanResult.smoothed;
Expand Down Expand Up @@ -1259,6 +1263,9 @@ class KalmanFitter {
}
track.nMeasurements() = kalmanResult.measurementStates;
track.nHoles() = kalmanResult.measurementHoles;

calculateFitQuality(track);

if (trackContainer.hasColumn(hashString("smoothed"))) {
track.template component<bool, hashString("smoothed")>() =
kalmanResult.smoothed;
Expand Down
13 changes: 12 additions & 1 deletion Core/src/EventData/VectorTrackContainer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ VectorTrackContainerBase::VectorTrackContainerBase(
m_nMeasurements{other.m_nMeasurements},
m_nHoles{other.m_nHoles},
m_chi2{other.m_chi2},
m_ndf{other.m_ndf} {
m_ndf{other.m_ndf},
m_nOutliers{other.m_nOutliers},
m_nSharedHits{other.m_nSharedHits} {
for (const auto& [key, value] : other.m_dynamic) {
m_dynamic.insert({key, value->clone()});
}
Expand All @@ -44,6 +46,9 @@ VectorTrackContainer::IndexType VectorTrackContainer::addTrack_impl() {
m_chi2.emplace_back();
m_ndf.emplace_back();

m_nOutliers.emplace_back();
m_nSharedHits.emplace_back();

// dynamic columns
for (auto& [key, vec] : m_dynamic) {
vec->add();
Expand Down Expand Up @@ -74,6 +79,9 @@ void VectorTrackContainer::removeTrack_impl(IndexType itrack) {
erase(m_chi2);
erase(m_ndf);

erase(m_nOutliers);
erase(m_nSharedHits);

for (auto& [key, vec] : m_dynamic) {
vec->erase(itrack);
}
Expand Down Expand Up @@ -114,6 +122,9 @@ void VectorTrackContainer::reserve(IndexType size) {
m_chi2.reserve(size);
m_ndf.reserve(size);

m_nOutliers.reserve(size);
m_nSharedHits.reserve(size);

for (auto& [key, vec] : m_dynamic) {
vec->reserve(size);
}
Expand Down

0 comments on commit 83aa3e9

Please sign in to comment.