Skip to content

Commit

Permalink
Merge branch 'main' into fix/track-state-range-notip
Browse files Browse the repository at this point in the history
  • Loading branch information
kodiakhq[bot] authored Mar 28, 2023
2 parents 6d1102e + 31e5b09 commit 3b58c8d
Show file tree
Hide file tree
Showing 12 changed files with 435 additions and 184 deletions.
2 changes: 1 addition & 1 deletion Core/include/Acts/Seeding/SeedConfirmationRangeConfig.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ struct SeedConfirmationRangeConfig {
// z minimum and maximum of middle component of the seed used to define the
// region of the detector for seed confirmation
float zMinSeedConf =
std::numeric_limits<float>::min(); // Acts::UnitConstants::mm
std::numeric_limits<float>::lowest(); // Acts::UnitConstants::mm
float zMaxSeedConf =
std::numeric_limits<float>::max(); // Acts::UnitConstants::mm
// radius of bottom component of seed that is used to define the number of
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@

#pragma once

#include "ActsExamples/EventData/IndexSourceLink.hpp"
#include "ActsExamples/EventData/Trajectories.hpp"
#include "ActsExamples/EventData/Track.hpp"
#include "ActsExamples/Framework/DataHandle.hpp"
#include "ActsExamples/Framework/IAlgorithm.hpp"

Expand All @@ -32,9 +31,9 @@ class AmbiguityResolutionAlgorithm final : public IAlgorithm {
public:
struct Config {
/// Input trajectories collection.
std::string inputTrajectories;
std::string inputTracks;
/// Output trajectories collection.
std::string outputTrajectories;
std::string outputTracks;

/// Maximum amount of shared hits per track.
std::uint32_t maximumSharedHits = 1;
Expand Down Expand Up @@ -62,14 +61,8 @@ class AmbiguityResolutionAlgorithm final : public IAlgorithm {

private:
Config m_cfg;

ReadDataHandle<IndexSourceLinkContainer> m_inputSourceLinks{
this, "InputSourceLinks"};
ReadDataHandle<TrajectoriesContainer> m_inputTrajectories{
this, "InputTrajectories"};

WriteDataHandle<TrajectoriesContainer> m_outputTrajectories{
this, "OutputTrajectories"};
ReadDataHandle<ConstTrackContainer> m_inputTracks{this, "InputTracks"};
WriteDataHandle<ConstTrackContainer> m_outputTracks{this, "OutputTracks"};
};

} // namespace ActsExamples
126 changes: 53 additions & 73 deletions Examples/Algorithms/TrackFinding/src/AmbiguityResolutionAlgorithm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
#include "Acts/EventData/MultiTrajectoryHelpers.hpp"
#include "ActsExamples/EventData/IndexSourceLink.hpp"
#include "ActsExamples/EventData/Measurement.hpp"
#include "ActsExamples/EventData/Track.hpp"
#include "ActsExamples/EventData/Trajectories.hpp"
#include "ActsExamples/Framework/ProcessCode.hpp"
#include "ActsExamples/Framework/WhiteBoard.hpp"

Expand All @@ -28,25 +26,23 @@ ActsExamples::AmbiguityResolutionAlgorithm::AmbiguityResolutionAlgorithm(
Acts::Logging::Level lvl)
: ActsExamples::IAlgorithm("AmbiguityResolutionAlgorithm", lvl),
m_cfg(std::move(cfg)) {
if (m_cfg.inputTrajectories.empty()) {
if (m_cfg.inputTracks.empty()) {
throw std::invalid_argument("Missing trajectories input collection");
}
if (m_cfg.outputTrajectories.empty()) {
if (m_cfg.outputTracks.empty()) {
throw std::invalid_argument("Missing trajectories output collection");
}

m_inputTrajectories.initialize(m_cfg.inputTrajectories);
m_outputTrajectories.initialize(m_cfg.outputTrajectories);
m_inputTracks.initialize(m_cfg.inputTracks);
m_outputTracks.initialize(m_cfg.outputTracks);
}

namespace {

struct State {
std::size_t numberOfTracks{};

std::vector<std::pair<std::size_t, std::size_t>> trackTips;
std::vector<int> trackTips;
std::vector<float> trackChi2;
std::vector<ActsExamples::TrackParameters> trackParameters;
std::vector<std::vector<std::size_t>> measurementsPerTrack;

boost::container::flat_map<std::size_t,
Expand All @@ -57,56 +53,41 @@ struct State {
boost::container::flat_set<std::size_t> selectedTracks;
};

State computeInitialState(
const ActsExamples::TrajectoriesContainer& trajectories,
std::size_t nMeasurementsMin) {
State computeInitialState(const ActsExamples::ConstTrackContainer& tracks,
std::size_t nMeasurementsMin) {
State state;

for (std::size_t iTrack = 0, iTraj = 0; iTraj < trajectories.size();
++iTraj) {
const auto& traj = trajectories[iTraj];
for (auto tip : traj.tips()) {
if (!traj.hasTrackParameters(tip)) {
continue;
}

auto trajState = Acts::MultiTrajectoryHelpers::trajectoryState(
traj.multiTrajectory(), tip);
if (trajState.nMeasurements < nMeasurementsMin) {
continue;
}

std::vector<std::size_t> measurements;
traj.multiTrajectory().visitBackwards(tip, [&](const auto& hit) {
if (hit.typeFlags().test(Acts::TrackStateFlag::MeasurementFlag)) {
std::size_t iMeasurement =
hit.getUncalibratedSourceLink()
.template get<ActsExamples::IndexSourceLink>()
.index();
measurements.push_back(iMeasurement);
}
return true;
});

++state.numberOfTracks;

state.trackTips.emplace_back(iTraj, tip);
state.trackChi2.push_back(trajState.chi2Sum / trajState.NDF);
state.trackParameters.push_back(traj.trackParameters(tip));
state.measurementsPerTrack.push_back(std::move(measurements));

state.selectedTracks.insert(iTrack);

++iTrack;
for (const auto& track : tracks) {
auto trajState = Acts::MultiTrajectoryHelpers::trajectoryState(
tracks.trackStateContainer(), track.tipIndex());
if (trajState.nMeasurements < nMeasurementsMin) {
continue;
}
std::vector<std::size_t> measurements;
tracks.trackStateContainer().visitBackwards(
track.tipIndex(), [&](const auto& hit) {
if (hit.typeFlags().test(Acts::TrackStateFlag::MeasurementFlag)) {
std::size_t iMeasurement =
hit.getUncalibratedSourceLink()
.template get<ActsExamples::IndexSourceLink>()
.index();
measurements.push_back(iMeasurement);
}
return true;
});

state.trackTips.push_back(track.index());
state.trackChi2.push_back(trajState.chi2Sum / trajState.NDF);
state.measurementsPerTrack.push_back(std::move(measurements));
state.selectedTracks.insert(state.numberOfTracks);

++state.numberOfTracks;
}

for (std::size_t iTrack = 0; iTrack < state.numberOfTracks; ++iTrack) {
for (auto iMeasurement : state.measurementsPerTrack[iTrack]) {
state.tracksPerMeasurement[iMeasurement].insert(iTrack);
}
}

state.sharedMeasurementsPerTrack =
std::vector<std::size_t>(state.trackTips.size(), 0);

Expand Down Expand Up @@ -138,9 +119,8 @@ void removeTrack(State& state, std::size_t iTrack) {

ActsExamples::ProcessCode ActsExamples::AmbiguityResolutionAlgorithm::execute(
const AlgorithmContext& ctx) const {
const auto& trajectories = m_inputTrajectories(ctx);

auto state = computeInitialState(trajectories, m_cfg.nMeasurementsMin);
const auto& tracks = m_inputTracks(ctx);
auto state = computeInitialState(tracks, m_cfg.nMeasurementsMin);

auto sharedMeasurementsComperator = [&state](std::size_t a, std::size_t b) {
return state.sharedMeasurementsPerTrack[a] <
Expand Down Expand Up @@ -180,27 +160,27 @@ ActsExamples::ProcessCode ActsExamples::AmbiguityResolutionAlgorithm::execute(
ACTS_INFO("Resolved to " << state.selectedTracks.size() << " tracks from "
<< state.trackTips.size());

TrajectoriesContainer outputTrajectories;
outputTrajectories.reserve(trajectories.size());
for (std::size_t iTraj = 0; iTraj < trajectories.size(); ++iTraj) {
const auto& traj = trajectories[iTraj];

std::vector<Acts::MultiTrajectoryTraits::IndexType> tips;
Trajectories::IndexedParameters parameters;

for (auto iTrack : state.selectedTracks) {
if (state.trackTips[iTrack].first != iTraj) {
continue;
}
const auto tip = state.trackTips[iTrack].second;
tips.push_back(tip);
parameters.emplace(tip, state.trackParameters[iTrack]);
}
if (!tips.empty()) {
outputTrajectories.emplace_back(traj.multiTrajectory(), tips, parameters);
}
std::shared_ptr<Acts::ConstVectorMultiTrajectory> trackStateContainer =
tracks.trackStateContainerHolder();
auto trackContainer = std::make_shared<Acts::VectorTrackContainer>();
trackContainer->reserve(state.selectedTracks.size());
// temporary empty track state container: we don't change the original one,
// but we need one for filtering
auto tempTrackStateContainer =
std::make_shared<Acts::VectorMultiTrajectory>();

TrackContainer solvedTracks{trackContainer, tempTrackStateContainer};
solvedTracks.ensureDynamicColumns(tracks);

for (auto iTrack : state.selectedTracks) {
auto destProxy = solvedTracks.getTrack(solvedTracks.addTrack());
destProxy.copyFrom(tracks.getTrack(state.trackTips.at(iTrack)));
}

m_outputTrajectories(ctx, std::move(outputTrajectories));
ActsExamples::ConstTrackContainer outputTracks{
std::make_shared<Acts::ConstVectorTrackContainer>(
std::move(*trackContainer)),
trackStateContainer};
m_outputTracks(ctx, std::move(outputTracks));
return ActsExamples::ProcessCode::SUCCESS;
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#pragma once

#include "Acts/Plugins/Onnx/OnnxRuntimeBase.hpp"
#include "ActsExamples/EventData/Trajectories.hpp"
#include "ActsExamples/EventData/Track.hpp"
#include "ActsExamples/Framework/DataHandle.hpp"
#include "ActsExamples/Framework/IAlgorithm.hpp"

Expand All @@ -28,11 +28,11 @@ class AmbiguityResolutionMLAlgorithm final : public IAlgorithm {
public:
struct Config {
/// Input trajectories collection.
std::string inputTrajectories;
std::string inputTracks;
/// path to the ONNX model for the duplicate neural network
std::string inputDuplicateNN;
/// Output trajectories collection.
std::string outputTrajectories;
std::string outputTracks;
/// Minumum number of measurement to form a track.
int nMeasurementsMin = 7;
};
Expand All @@ -58,12 +58,8 @@ class AmbiguityResolutionMLAlgorithm final : public IAlgorithm {
Ort::Env m_env;
// ONNX model for the duplicate neural network
Acts::OnnxRuntimeBase m_duplicateClassifier;

ReadDataHandle<TrajectoriesContainer> m_inputTrajectories{
this, "InputTrajectories"};

WriteDataHandle<TrajectoriesContainer> m_outputTrajectories{
this, "OutputTrajectories"};
ReadDataHandle<ConstTrackContainer> m_inputTracks{this, "InputTracks"};
WriteDataHandle<ConstTrackContainer> m_outputTracks{this, "OutputTracks"};
};

} // namespace ActsExamples
Loading

0 comments on commit 3b58c8d

Please sign in to comment.