Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add Mlpack DBScan clustering to the ML Ambiguity solver #2005

Merged
merged 32 commits into from
Apr 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
af516e5
reverse ML solver changes
Corentin-Allaire Mar 16, 2023
ed8f370
ML Solver
Corentin-Allaire Mar 24, 2023
caa7fda
conflict
Corentin-Allaire Mar 24, 2023
a95d9a8
switch solver to TrackContainer
Corentin-Allaire Mar 27, 2023
f5e4cf7
conflict
Corentin-Allaire Mar 27, 2023
95a0d59
conflict
Corentin-Allaire Mar 27, 2023
7bb8c33
Merge remote-tracking branch 'upstream/main' into mlpack
Corentin-Allaire Mar 27, 2023
95f5fd3
Track Container
Corentin-Allaire Mar 27, 2023
a59408e
py format
Corentin-Allaire Mar 27, 2023
f815551
py format
Corentin-Allaire Mar 27, 2023
30fd23a
cleaning includes
Corentin-Allaire Mar 27, 2023
74c8a45
conflict
Corentin-Allaire Mar 30, 2023
72d1f4e
AmbiguityResolutionML
Corentin-Allaire Mar 30, 2023
d5c4b92
pythonbinding
Corentin-Allaire Mar 30, 2023
a60c8cc
onnx plugin
Corentin-Allaire Mar 30, 2023
fcfcbea
mlpack DBScan
Corentin-Allaire Mar 30, 2023
dd6e5d9
find mlpack
Corentin-Allaire Mar 30, 2023
5e17cdc
clustering
Corentin-Allaire Mar 30, 2023
4b7fc9b
py format
Corentin-Allaire Mar 30, 2023
bb1625c
conflict
Corentin-Allaire Mar 30, 2023
a9ed61d
conflict
Corentin-Allaire Mar 30, 2023
09a4538
cmake fix
Corentin-Allaire Mar 30, 2023
2e7daad
CMakeLists for mlpack
Corentin-Allaire Mar 30, 2023
7a4d3cd
cmake option
Corentin-Allaire Mar 31, 2023
8286dbe
Merge remote-tracking branch 'upstream/main' into mlpack
Corentin-Allaire Mar 31, 2023
001d499
cmake option
Corentin-Allaire Mar 31, 2023
bf35c49
cmake option
Corentin-Allaire Mar 31, 2023
d8628ba
Merge remote-tracking branch 'upstream/main' into mlpack
Corentin-Allaire Apr 11, 2023
25db0d9
include TrackContainer
Corentin-Allaire Apr 11, 2023
279111a
Apply alex's suggestions
Corentin-Allaire Apr 11, 2023
c8d0c42
Merge branch 'main' into mlpack
Corentin-Allaire Apr 11, 2023
8c1a266
Merge branch 'main' into mlpack
kodiakhq[bot] Apr 11, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ option(ACTS_BUILD_PLUGIN_JSON "Build json plugin" OFF)
option(ACTS_USE_SYSTEM_NLOHMANN_JSON "Use nlohmann::json provided by the system instead of the bundled version" ${ACTS_USE_SYSTEM_LIBS})
option(ACTS_BUILD_PLUGIN_LEGACY "Build legacy plugin" OFF)
option(ACTS_BUILD_PLUGIN_ONNX "Build ONNX plugin" OFF)
option(ACTS_BUILD_PLUGIN_MLPACK "Build MLpack plugin" OFF)
option(ACTS_SETUP_VECMEM "Explicitly set up vecmem for the project" OFF)
option(ACTS_USE_SYSTEM_VECMEM "Use a system-provided vecmem installation" ${ACTS_USE_SYSTEM_LIBS})
option(ACTS_BUILD_PLUGIN_SYCL "Build SYCL plugin" OFF)
Expand Down Expand Up @@ -181,6 +182,7 @@ set(_acts_eigen3_version 3.3.7)
set(_acts_hepmc3_version 3.2.1)
set(_acts_nlohmanjson_version 3.2.0)
set(_acts_onnxruntime_version 1.12.0)
set(_acts_mlpack_version 3.1.1)
set(_acts_root_version 6.20)
set(_acts_tbb_version 2020.1)

Expand Down Expand Up @@ -292,6 +294,10 @@ endif()
if(ACTS_BUILD_PLUGIN_ONNX)
find_package(OnnxRuntime ${_acts_onnxruntime_version} REQUIRED)
endif()
# mlpack is only searched for when the mlpack plugin is enabled; the version
# requested is the one stored in _acts_mlpack_version.
if(ACTS_BUILD_PLUGIN_MLPACK)
find_package(mlpack ${_acts_mlpack_version} REQUIRED)
# Expose the mlpack headers as SYSTEM include directories.
# NOTE(review): this is a directory-wide include, not a per-target one.
include_directories(SYSTEM ${mlpack_INCLUDE_DIR})
endif()
if(ACTS_BUILD_PLUGIN_SYCL)
find_package(SYCL REQUIRED)
endif()
Expand Down
26 changes: 26 additions & 0 deletions Core/include/Acts/TrackFinding/detail/AmbiguityTrackClustering.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// This file is part of the Acts project.
//
// Copyright (C) 2023 CERN for the benefit of the Acts project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

#pragma once

#include <map>
#include <unordered_map>
#include <vector>

namespace Acts {
namespace detail {

/// Cluster tracks based on shared hits.
///
/// Tracks are processed in decreasing key order. A track sharing at least one hit with the primary track of an existing cluster is added to that cluster; otherwise it becomes the primary track of a new cluster.
///
/// @param trackMap : Multimap storing pairs of track ID and vector of measurement IDs. The keys are the number of measurements and are only there to facilitate the ordering.
/// @return an unordered map representing the clusters: the keys are the ID of the primary track of each cluster and the values are vectors holding the IDs of all tracks of that cluster (primary track included).
std::unordered_map<int, std::vector<int>> clusterDuplicateTracks(
const std::multimap<int, std::pair<int, std::vector<int>>>& trackMap);

} // namespace detail
} // namespace Acts
45 changes: 45 additions & 0 deletions Core/src/TrackFinding/AmbiguityTrackClustering.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// This file is part of the Acts project.
//
// Copyright (C) 2023 CERN for the benefit of the Acts project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

#include "Acts/TrackFinding/detail/AmbiguityTrackClustering.hpp"

namespace Acts {
namespace detail {

// Cluster tracks that share hits.
//
// The tracks are visited in decreasing key order (the key of trackMap is the
// number of measurements). The first track claiming a hit becomes the primary
// track of a new cluster; any later track sharing at least one hit with a
// primary track is appended to that primary track's cluster. Hits of
// non-primary tracks are never registered, so clusters are only grown through
// their primary track.
//
// @param trackMap multimap of (track ID, measurement IDs) pairs
// @return map from primary track ID to the IDs of all tracks in its cluster
//         (the primary track itself is always the first entry)
std::unordered_map<int, std::vector<int>> clusterDuplicateTracks(
    const std::multimap<int, std::pair<int, std::vector<int>>>& trackMap) {
  // Primary track ID -> IDs of all the tracks of the cluster
  std::unordered_map<int, std::vector<int>> cluster;
  // Hit ID -> ID of the primary track owning that hit
  std::unordered_map<int, int> hitToTrack;

  // Reverse iteration: largest number of measurements first
  for (auto track = trackMap.rbegin(); track != trackMap.rend(); ++track) {
    const int trackId = track->second.first;
    // Bind by reference: the hits are only read, no need to copy the vector
    const std::vector<int>& hits = track->second.second;

    // Look for the first hit already owned by a primary track
    auto matchedTrack = hitToTrack.end();
    for (const int hit : hits) {
      matchedTrack = hitToTrack.find(hit);
      if (matchedTrack != hitToTrack.end()) {
        break;
      }
    }

    if (matchedTrack != hitToTrack.end()) {
      // Shared hit found: attach the track to the matching cluster
      cluster.at(matchedTrack->second).push_back(trackId);
      continue;
    }

    // No shared hit: the track becomes the primary track of a new cluster
    // and takes ownership of all its hits
    cluster.emplace(trackId, std::vector<int>(1, trackId));
    for (const int hit : hits) {
      hitToTrack.emplace(hit, trackId);
    }
  }
  return cluster;
}

}  // namespace detail
}  // namespace Acts
1 change: 1 addition & 0 deletions Core/src/TrackFinding/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@ target_sources(
PRIVATE
CombinatorialKalmanFilterError.cpp
MeasurementSelector.cpp
AmbiguityTrackClustering.cpp
)
18 changes: 17 additions & 1 deletion Examples/Algorithms/TrackFindingML/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
# Sources that are always part of the library.
set(SOURCES
src/AmbiguityResolutionML.cpp
src/AmbiguityResolutionMLAlgorithm.cpp
)

# The DBScan-based algorithm is only compiled when the mlpack plugin is enabled.
if(ACTS_BUILD_PLUGIN_MLPACK)
list(APPEND SOURCES
src/AmbiguityResolutionMLDBScanAlgorithm.cpp
)
endif()

add_library(
ActsExamplesTrackFindingML SHARED
src/AmbiguityResolutionMLAlgorithm.cpp
${SOURCES}
)

target_include_directories(
Expand All @@ -16,6 +27,11 @@ target_link_libraries(
ActsExamplesFramework
)

# Link against the mlpack plugin only when it is enabled.
if(ACTS_BUILD_PLUGIN_MLPACK)
target_link_libraries(
ActsExamplesTrackFindingML PUBLIC ActsPluginmlpack)
endif()

install(
TARGETS ActsExamplesTrackFindingML
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// This file is part of the Acts project.
//
// Copyright (C) 2023 CERN for the benefit of the Acts project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

#pragma once

#include "ActsExamples/EventData/Track.hpp"
#include "ActsExamples/Framework/IAlgorithm.hpp"

#include <map>
#include <string>
#include <vector>

namespace ActsExamples {

/// Generic implementation of the machine learning ambiguity resolution.
/// Contains the methods used for data preparation.
class AmbiguityResolutionML : public IAlgorithm {
public:
/// Construct the ambiguity resolution algorithm.
///
/// @param name name of the algorithm
/// @param lvl is the logging level
AmbiguityResolutionML(std::string name, Acts::Logging::Level lvl);

protected:
/// Associate measurement IDs to track IDs
///
/// @param tracks is the input track container
/// @param nMeasurementsMin minimum number of measurements per track; tracks with fewer measurements are not added to the map
/// @return a multimap keyed by the number of measurements of each track, containing pairs of track ID and associated measurement IDs
std::multimap<int, std::pair<int, std::vector<int>>> mapTrackHits(
const ConstTrackContainer& tracks, int nMeasurementsMin) const;

/// Prepare the output track container to be written
///
/// @param tracks is the input track container
/// @param goodTracks is the list of the IDs of all the tracks we want to keep
/// @return a track container holding copies of the selected tracks
ConstTrackContainer prepareOutputTrack(const ConstTrackContainer& tracks,
std::vector<int>& goodTracks) const;
};

} // namespace ActsExamples
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
// This file is part of the Acts project.
//
// Copyright (C) 2022 CERN for the benefit of the Acts project
// Copyright (C) 2023 CERN for the benefit of the Acts project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

#pragma once

#include "Acts/Plugins/Onnx/OnnxRuntimeBase.hpp"
#include "Acts/Plugins/Onnx/AmbiguityTrackClassifier.hpp"
#include "ActsExamples/EventData/Track.hpp"
#include "ActsExamples/Framework/DataHandle.hpp"
#include "ActsExamples/Framework/IAlgorithm.hpp"
#include "ActsExamples/TrackFindingML/AmbiguityResolutionML.hpp"

#include <string>
#include <vector>

namespace ActsExamples {

Expand All @@ -24,14 +23,14 @@ namespace ActsExamples {
/// 1) Cluster together nearby tracks using shared hits
/// 2) For each track use a neural network to compute a score
/// 3) In each cluster keep the track with the highest score
class AmbiguityResolutionMLAlgorithm final : public IAlgorithm {
class AmbiguityResolutionMLAlgorithm final : public AmbiguityResolutionML {
public:
struct Config {
/// Input trajectories collection.
/// Input track collection.
std::string inputTracks;
/// path to the ONNX model for the duplicate neural network
/// Path to the ONNX model for the duplicate neural network
std::string inputDuplicateNN;
/// Output trajectories collection.
/// Output track collection.
std::string outputTracks;
/// Minimum number of measurements to form a track.
int nMeasurementsMin = 7;
Expand All @@ -54,10 +53,8 @@ class AmbiguityResolutionMLAlgorithm final : public IAlgorithm {

private:
Config m_cfg;
// ONNX environement
Ort::Env m_env;
// ONNX model for the duplicate neural network
Acts::OnnxRuntimeBase m_duplicateClassifier;
// ONNX model for track selection
Acts::AmbiguityTrackClassifier m_duplicateClassifier;
ReadDataHandle<ConstTrackContainer> m_inputTracks{this, "InputTracks"};
WriteDataHandle<ConstTrackContainer> m_outputTracks{this, "OutputTracks"};
};
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// This file is part of the Acts project.
//
// Copyright (C) 2023 CERN for the benefit of the Acts project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

#pragma once

#include "Acts/Plugins/Onnx/AmbiguityTrackClassifier.hpp"
#include "ActsExamples/EventData/Track.hpp"
#include "ActsExamples/Framework/DataHandle.hpp"
#include "ActsExamples/TrackFindingML/AmbiguityResolutionML.hpp"

#include <string>

namespace ActsExamples {

/// Evicts tracks that seem to be duplicated and fake.
///
/// The implementation works as follows:
/// 1) Cluster together nearby tracks using a DBScan
/// 2) Create subclusters based on tracks with shared hits
/// 3) For each track use a neural network to compute a score
/// 4) In each cluster keep the track with the highest score
class AmbiguityResolutionMLDBScanAlgorithm final
: public AmbiguityResolutionML {
public:
struct Config {
/// Input track collection.
std::string inputTracks;
/// Path to the ONNX model for the duplicate neural network
std::string inputDuplicateNN;
/// Output track collection.
std::string outputTracks;
/// Minimum number of measurements to form a track.
int nMeasurementsMin = 7;
/// Maximum distance between 2 tracks to be clustered in the DBScan
float epsilonDBScan = 0.07;
/// Minimum number of tracks to create a cluster in the DBScan
int minPointsDBScan = 2;
};

/// Construct the ambiguity resolution algorithm.
///
/// @param cfg is the algorithm configuration
/// @param lvl is the logging level
AmbiguityResolutionMLDBScanAlgorithm(Config cfg, Acts::Logging::Level lvl);

/// Run the ambiguity resolution algorithm.
///
/// @param ctx is the algorithm context with event information
/// @return a process code indicating success or failure
ProcessCode execute(const AlgorithmContext& ctx) const final;

/// Const access to the config
const Config& config() const { return m_cfg; }

private:
/// Algorithm configuration
Config m_cfg;
// ONNX model for track selection
Acts::AmbiguityTrackClassifier m_duplicateClassifier;
/// Handle used to read the input track container
ReadDataHandle<ConstTrackContainer> m_inputTracks{this, "InputTracks"};
/// Handle used to write the output track container
WriteDataHandle<ConstTrackContainer> m_outputTracks{this, "OutputTracks"};
};

} // namespace ActsExamples
73 changes: 73 additions & 0 deletions Examples/Algorithms/TrackFindingML/src/AmbiguityResolutionML.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// This file is part of the Acts project.
//
// Copyright (C) 2023 CERN for the benefit of the Acts project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

#include "ActsExamples/TrackFindingML/AmbiguityResolutionML.hpp"

#include "ActsExamples/EventData/IndexSourceLink.hpp"
#include "ActsExamples/EventData/Measurement.hpp"

#include <memory>
#include <utility>
#include <vector>

// Construct the base algorithm.
//
// @param name name of the algorithm, forwarded to IAlgorithm
// @param lvl is the logging level, forwarded to IAlgorithm
ActsExamples::AmbiguityResolutionML::AmbiguityResolutionML(
    std::string name, Acts::Logging::Level lvl)
    // name is a by-value sink parameter: move it instead of copying it again
    : ActsExamples::IAlgorithm(std::move(name), lvl) {}

// Associate to each track the IDs of its measurements.
//
// @param tracks is the input track container
// @param nMeasurementsMin minimum number of measurements; tracks with fewer
//        measurements are not added to the map
// @return multimap keyed by the number of measurements of the track, each
//         value being the pair (track index, measurement indices)
std::multimap<int, std::pair<int, std::vector<int>>>
ActsExamples::AmbiguityResolutionML::mapTrackHits(
    const ActsExamples::ConstTrackContainer& tracks,
    int nMeasurementsMin) const {
  std::multimap<int, std::pair<int, std::vector<int>>> trackMap;
  // Loop over all the tracks in the event
  for (const auto& track : tracks) {
    std::vector<int> hits;
    // Collect the index of every measurement state, walking the track states
    // backwards from the tip of the track
    tracks.trackStateContainer().visitBackwards(
        track.tipIndex(), [&hits](const auto& state) {
          if (state.typeFlags().test(Acts::TrackStateFlag::MeasurementFlag)) {
            const int indexHit =
                state.getUncalibratedSourceLink()
                    .template get<ActsExamples::IndexSourceLink>()
                    .index();
            hits.emplace_back(indexHit);
          }
        });
    // One hit is stored per measurement state, so the number of measurements
    // is simply the number of collected hits
    const int nbMeasurements = static_cast<int>(hits.size());
    if (nbMeasurements < nMeasurementsMin) {
      continue;
    }
    // hits is not used after this point: move it into the map to avoid a copy
    trackMap.emplace(nbMeasurements,
                     std::make_pair(track.index(), std::move(hits)));
  }
  return trackMap;
}

// Build the output track container from the list of selected tracks.
//
// The selected tracks are copied into a new track container, while the track
// state container of the input is reused as is for the output.
//
// @param tracks is the input track container
// @param goodTracks is the list of the IDs of the tracks to keep
// @return a const track container holding only the selected tracks
ActsExamples::ConstTrackContainer
ActsExamples::AmbiguityResolutionML::prepareOutputTrack(
    const ActsExamples::ConstTrackContainer& tracks,
    std::vector<int>& goodTracks) const {
  // Backend receiving the copies of the selected tracks
  auto selectedBackend = std::make_shared<Acts::VectorTrackContainer>();
  selectedBackend->reserve(goodTracks.size());

  // The original track states are left untouched; an empty mutable track
  // state container is only needed to be able to build the temporary
  // container used for the filtering
  auto scratchStates = std::make_shared<Acts::VectorMultiTrajectory>();

  TrackContainer solvedTracks{selectedBackend, scratchStates};
  solvedTracks.ensureDynamicColumns(tracks);

  for (const int trackIndex : goodTracks) {
    const auto newIndex = solvedTracks.addTrack();
    auto destProxy = solvedTracks.getTrack(newIndex);
    destProxy.copyFrom(tracks.getTrack(trackIndex));
  }

  // The output shares the track states of the input container
  std::shared_ptr<Acts::ConstVectorMultiTrajectory> inputStates =
      tracks.trackStateContainerHolder();
  return ConstTrackContainer{
      std::make_shared<Acts::ConstVectorTrackContainer>(
          std::move(*selectedBackend)),
      inputStates};
}
Loading