From 7a27f7c40106b8626c2bf375ed9f50474f2d7269 Mon Sep 17 00:00:00 2001 From: Benjamin Huth Date: Mon, 27 Nov 2023 18:04:25 +0100 Subject: [PATCH 1/3] update --- .../TrackFindingAlgorithmExaTrkX.hpp | 6 ++ .../src/TrackFindingAlgorithmExaTrkX.cpp | 35 +++++++++++- Examples/Algorithms/Utilities/CMakeLists.txt | 1 + .../ActsExamples/Utilities/HitSelector.hpp | 50 ++++++++++++++++ .../Algorithms/Utilities/src/HitSelector.cpp | 33 +++++++++++ Examples/Io/Csv/CMakeLists.txt | 1 + .../Io/Csv/CsvExaTrkXGraphWriter.hpp | 57 +++++++++++++++++++ Examples/Io/Csv/src/CsvExaTrkXGraphWriter.cpp | 56 ++++++++++++++++++ Examples/Python/src/ExaTrkXTrackFinding.cpp | 6 +- Examples/Python/src/Output.cpp | 7 ++- Examples/Python/src/TruthTracking.cpp | 4 ++ Plugins/ExaTrkX/CMakeLists.txt | 1 + .../Acts/Plugins/ExaTrkX/ExaTrkXPipeline.hpp | 5 +- .../Plugins/ExaTrkX/TorchGraphStoreHook.hpp | 34 +++++++++++ .../ExaTrkX/TorchTruthGraphMetricsHook.hpp | 3 +- Plugins/ExaTrkX/src/ExaTrkXPipeline.cpp | 4 +- Plugins/ExaTrkX/src/TorchGraphStoreHook.cpp | 33 +++++++++++ .../src/TorchTruthGraphMetricsHook.cpp | 3 +- 18 files changed, 326 insertions(+), 13 deletions(-) create mode 100644 Examples/Algorithms/Utilities/include/ActsExamples/Utilities/HitSelector.hpp create mode 100644 Examples/Algorithms/Utilities/src/HitSelector.cpp create mode 100644 Examples/Io/Csv/include/ActsExamples/Io/Csv/CsvExaTrkXGraphWriter.hpp create mode 100644 Examples/Io/Csv/src/CsvExaTrkXGraphWriter.cpp create mode 100644 Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchGraphStoreHook.hpp create mode 100644 Plugins/ExaTrkX/src/TorchGraphStoreHook.cpp diff --git a/Examples/Algorithms/TrackFindingExaTrkX/include/ActsExamples/TrackFindingExaTrkX/TrackFindingAlgorithmExaTrkX.hpp b/Examples/Algorithms/TrackFindingExaTrkX/include/ActsExamples/TrackFindingExaTrkX/TrackFindingAlgorithmExaTrkX.hpp index d6366d25b43..6e01c7c35dd 100644 --- a/Examples/Algorithms/TrackFindingExaTrkX/include/ActsExamples/TrackFindingExaTrkX/TrackFindingAlgorithmExaTrkX.hpp +++ b/Examples/Algorithms/TrackFindingExaTrkX/include/ActsExamples/TrackFindingExaTrkX/TrackFindingAlgorithmExaTrkX.hpp @@ -11,6 +11,7 @@ #include "Acts/Definitions/Units.hpp" #include "Acts/Plugins/ExaTrkX/ExaTrkXPipeline.hpp" #include "Acts/Plugins/ExaTrkX/Stages.hpp" +#include "Acts/Plugins/ExaTrkX/TorchGraphStoreHook.hpp" #include "ActsExamples/EventData/Cluster.hpp" #include "ActsExamples/EventData/ProtoTrack.hpp" #include "ActsExamples/EventData/SimHit.hpp" @@ -52,6 +53,9 @@ class TrackFindingAlgorithmExaTrkX final : public IAlgorithm { /// Output protoTracks collection. std::string outputProtoTracks; + /// Output graph (optional) + std::string outputGraph; + std::shared_ptr graphConstructor; std::vector> edgeClassifiers; @@ -114,6 +118,8 @@ class TrackFindingAlgorithmExaTrkX final : public IAlgorithm { WriteDataHandle m_outputProtoTracks{this, "OutputProtoTracks"}; + WriteDataHandle m_outputGraph{ + this, "OutputGraph"}; // for truth graph ReadDataHandle m_inputSimHits{this, "InputSimHits"}; diff --git a/Examples/Algorithms/TrackFindingExaTrkX/src/TrackFindingAlgorithmExaTrkX.cpp b/Examples/Algorithms/TrackFindingExaTrkX/src/TrackFindingAlgorithmExaTrkX.cpp index 4215828d62c..819d55e969d 100644 --- a/Examples/Algorithms/TrackFindingExaTrkX/src/TrackFindingAlgorithmExaTrkX.cpp +++ b/Examples/Algorithms/TrackFindingExaTrkX/src/TrackFindingAlgorithmExaTrkX.cpp @@ -9,6 +9,7 @@ #include "ActsExamples/TrackFindingExaTrkX/TrackFindingAlgorithmExaTrkX.hpp" #include "Acts/Definitions/Units.hpp" +#include "Acts/Plugins/ExaTrkX/TorchGraphStoreHook.hpp" #include "Acts/Plugins/ExaTrkX/TorchTruthGraphMetricsHook.hpp" #include "Acts/Utilities/Zip.hpp" #include "ActsExamples/EventData/Index.hpp" @@ -31,6 +32,7 @@ class ExamplesEdmHook : public Acts::ExaTrkXHook { std::unique_ptr m_logger; std::unique_ptr m_truthGraphHook; std::unique_ptr m_targetGraphHook; + std::unique_ptr m_graphStoreHook; const Acts::Logger& logger() const { return *m_logger; } @@ -98,17 +100,22 @@ class ExamplesEdmHook : public Acts::ExaTrkXHook { truthGraph, logger.clone()); m_targetGraphHook = std::make_unique( targetGraph, logger.clone()); + m_graphStoreHook = std::make_unique(); } ~ExamplesEdmHook() {} - void operator()(const std::any& nodes, const std::any& edges) const override { + auto storedGraph() const { return m_graphStoreHook->storedGraph(); } + + void operator()(const std::any& nodes, const std::any& edges, + const std::any& weights) const override { ACTS_INFO("Metrics for total graph:"); - (*m_truthGraphHook)(nodes, edges); + (*m_truthGraphHook)(nodes, edges, weights); ACTS_INFO("Metrics for target graph (pT > " << m_targetPT / Acts::UnitConstants::GeV << " GeV, nHits >= " << m_targetSize << "):"); - (*m_targetGraphHook)(nodes, edges); + (*m_targetGraphHook)(nodes, edges, weights); + (*m_graphStoreHook)(nodes, edges, weights); } }; @@ -153,6 +160,8 @@ ActsExamples::TrackFindingAlgorithmExaTrkX::TrackFindingAlgorithmExaTrkX( m_inputParticles.maybeInitialize(m_cfg.inputParticles); m_inputMeasurementMap.maybeInitialize(m_cfg.inputMeasurementSimhitsMap); + m_outputGraph.maybeInitialize(m_cfg.outputGraph); + // reserve space for timing m_timing.classifierTimes.resize( m_cfg.edgeClassifiers.size(), @@ -267,15 +276,35 @@ ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmExaTrkX::execute( // Make the prototracks std::vector protoTracks; protoTracks.reserve(trackCandidates.size()); + + int nShortTracks = 0; + for (auto& x : trackCandidates) { + if (x.size() < 3) { + nShortTracks++; + continue; + } + ProtoTrack onetrack; + onetrack.reserve(x.size()); + std::copy(x.begin(), x.end(), std::back_inserter(onetrack)); protoTracks.push_back(std::move(onetrack)); } + ACTS_INFO("Removed " << nShortTracks << " with less then 3 hits"); ACTS_INFO("Created " << protoTracks.size() << " proto tracks"); m_outputProtoTracks(ctx, std::move(protoTracks)); + if (auto dhook = dynamic_cast(&*hook); + dhook && m_outputGraph.isInitialized()) { + auto graph = dhook->storedGraph(); + std::transform( + graph.first.begin(), graph.first.end(), graph.first.begin(), + [&](const auto& a) -> int64_t { return spacepointIDs.at(a); }); + m_outputGraph(ctx, std::move(graph)); + } + return ActsExamples::ProcessCode::SUCCESS; } diff --git a/Examples/Algorithms/Utilities/CMakeLists.txt b/Examples/Algorithms/Utilities/CMakeLists.txt index c99737e1305..855f3bd49b7 100644 --- a/Examples/Algorithms/Utilities/CMakeLists.txt +++ b/Examples/Algorithms/Utilities/CMakeLists.txt @@ -5,6 +5,7 @@ add_library( src/TrajectoriesToPrototracks.cpp src/TrackSelectorAlgorithm.cpp src/TracksToTrajectories.cpp + src/HitSelector.cpp src/TracksToParameters.cpp) target_include_directories( ActsExamplesUtilities diff --git a/Examples/Algorithms/Utilities/include/ActsExamples/Utilities/HitSelector.hpp b/Examples/Algorithms/Utilities/include/ActsExamples/Utilities/HitSelector.hpp new file mode 100644 index 00000000000..ee356445670 --- /dev/null +++ b/Examples/Algorithms/Utilities/include/ActsExamples/Utilities/HitSelector.hpp @@ -0,0 +1,50 @@ +// This file is part of the Acts project. +// +// Copyright (C) 2019-2023 CERN for the benefit of the Acts project +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#pragma once + +#include "Acts/TrackFinding/TrackSelector.hpp" +#include "Acts/Utilities/Logger.hpp" +#include "ActsExamples/EventData/SimHit.hpp" +#include "ActsExamples/Framework/DataHandle.hpp" +#include "ActsExamples/Framework/IAlgorithm.hpp" + +#include +#include +#include + +namespace ActsExamples { + +/// Select tracks by applying some selection cuts. +class HitSelector final : public IAlgorithm { + public: + struct Config { + /// Input track collection. + std::string inputHits; + /// Output track collection + std::string outputHits; + + /// Time cut + double maxTime = std::numeric_limits::max(); + }; + + HitSelector(const Config& config, Acts::Logging::Level level); + + ProcessCode execute(const AlgorithmContext& ctx) const final; + + /// Get readonly access to the config parameters + const Config& config() const { return m_cfg; } + + private: + Config m_cfg; + + ReadDataHandle m_inputHits{this, "InputHits"}; + WriteDataHandle m_outputHits{this, "OutputHits"}; +}; + +} // namespace ActsExamples diff --git a/Examples/Algorithms/Utilities/src/HitSelector.cpp b/Examples/Algorithms/Utilities/src/HitSelector.cpp new file mode 100644 index 00000000000..1da7e059483 --- /dev/null +++ b/Examples/Algorithms/Utilities/src/HitSelector.cpp @@ -0,0 +1,33 @@ +// This file is part of the Acts project. +// +// Copyright (C) 2019-2023 CERN for the benefit of the Acts project +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#include "ActsExamples/Utilities/HitSelector.hpp" + +ActsExamples::HitSelector::HitSelector(const Config& config, + Acts::Logging::Level level) + : IAlgorithm("HitSelector", level), m_cfg(config) { + m_inputHits.initialize(m_cfg.inputHits); + m_outputHits.initialize(m_cfg.outputHits); +} + +ActsExamples::ProcessCode ActsExamples::HitSelector::execute( + const ActsExamples::AlgorithmContext& ctx) const { + const auto& hits = m_inputHits(ctx); + SimHitContainer selectedHits; + + std::copy_if(hits.begin(), hits.end(), + std::inserter(selectedHits, selectedHits.begin()), + [&](const auto& hit) { return hit.time() < m_cfg.maxTime; }); + + ACTS_DEBUG("selected " << selectedHits.size() << " from " << hits.size() + << " hits"); + + m_outputHits(ctx, std::move(selectedHits)); + + return {}; +} diff --git a/Examples/Io/Csv/CMakeLists.txt b/Examples/Io/Csv/CMakeLists.txt index d8faabbcc36..8d07c38a8cf 100644 --- a/Examples/Io/Csv/CMakeLists.txt +++ b/Examples/Io/Csv/CMakeLists.txt @@ -18,6 +18,7 @@ add_library( src/CsvTrackWriter.cpp src/CsvProtoTrackWriter.cpp src/CsvSpacePointWriter.cpp + src/CsvExaTrkXGraphWriter.cpp src/CsvBFieldWriter.cpp) target_include_directories( ActsExamplesIoCsv diff --git a/Examples/Io/Csv/include/ActsExamples/Io/Csv/CsvExaTrkXGraphWriter.hpp b/Examples/Io/Csv/include/ActsExamples/Io/Csv/CsvExaTrkXGraphWriter.hpp new file mode 100644 index 00000000000..9a420f90b23 --- /dev/null +++ b/Examples/Io/Csv/include/ActsExamples/Io/Csv/CsvExaTrkXGraphWriter.hpp @@ -0,0 +1,57 @@ +// This file is part of the Acts project. +// +// Copyright (C) 2020 CERN for the benefit of the Acts project +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#pragma once + +#include "Acts/Utilities/Logger.hpp" +#include "ActsExamples/Framework/ProcessCode.hpp" +#include "ActsExamples/Framework/WriterT.hpp" +#include "ActsExamples/Utilities/Paths.hpp" + +#include +#include +#include + +namespace ActsExamples { +struct AlgorithmContext; + +class CsvExaTrkXGraphWriter final + : public WriterT, std::vector>> { + public: + struct Config { + /// Which simulated (truth) hits collection to use. + std::string inputGraph; + /// Where to place output files + std::string outputDir; + /// Output filename stem. + std::string outputStem = "exatrkx-graph"; + }; + + /// Construct the cluster writer. + /// + /// @param config is the configuration object + /// @param level is the logging level + CsvExaTrkXGraphWriter(const Config& config, Acts::Logging::Level level); + + /// Readonly access to the config + const Config& config() const { return m_cfg; } + + protected: + /// Type-specific write implementation. + /// + /// @param[in] ctx is the algorithm context + /// @param[in] simHits are the simhits to be written + ProcessCode writeT(const AlgorithmContext& ctx, + const std::pair, std::vector>& + graph) override; + + private: + Config m_cfg; +}; + +} // namespace ActsExamples diff --git a/Examples/Io/Csv/src/CsvExaTrkXGraphWriter.cpp b/Examples/Io/Csv/src/CsvExaTrkXGraphWriter.cpp new file mode 100644 index 00000000000..9ee9998b524 --- /dev/null +++ b/Examples/Io/Csv/src/CsvExaTrkXGraphWriter.cpp @@ -0,0 +1,56 @@ +// This file is part of the Acts project. +// +// Copyright (C) 2020 CERN for the benefit of the Acts project +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#include "ActsExamples/Io/Csv/CsvExaTrkXGraphWriter.hpp" + +#include "Acts/Definitions/Algebra.hpp" +#include "Acts/Definitions/Common.hpp" +#include "Acts/Definitions/Units.hpp" +#include "ActsExamples/Framework/AlgorithmContext.hpp" +#include "ActsExamples/Utilities/Paths.hpp" +#include "ActsFatras/EventData/Barcode.hpp" + +#include +#include + +#include +#include + +struct GraphData { + int64_t edge0; + int64_t edge1; + float weight; + DFE_NAMEDTUPLE(GraphData, edge0, edge1, weight); +}; + +ActsExamples::CsvExaTrkXGraphWriter::CsvExaTrkXGraphWriter( + const ActsExamples::CsvExaTrkXGraphWriter::Config& config, + Acts::Logging::Level level) + : WriterT(config.inputGraph, "CsvExaTrkXGraphWriter", level), + m_cfg(config) {} + +ActsExamples::ProcessCode ActsExamples::CsvExaTrkXGraphWriter::writeT( + const ActsExamples::AlgorithmContext& ctx, + const std::pair, std::vector>& graph) { + std::string path = perEventFilepath( + m_cfg.outputDir, m_cfg.outputStem + ".csv", ctx.eventNumber); + + dfe::NamedTupleCsvWriter writer(path); + + const auto& [edges, weights] = graph; + + for (auto i = 0ul; i < weights.size(); ++i) { + GraphData edge; + edge.edge0 = edges[2 * i]; + edge.edge1 = edges[2 * i + 1]; + edge.weight = weights[i]; + writer.append(edge); + } + + return ActsExamples::ProcessCode::SUCCESS; +} diff --git a/Examples/Python/src/ExaTrkXTrackFinding.cpp b/Examples/Python/src/ExaTrkXTrackFinding.cpp index ad27045da03..3131754e802 100644 --- a/Examples/Python/src/ExaTrkXTrackFinding.cpp +++ b/Examples/Python/src/ExaTrkXTrackFinding.cpp @@ -169,9 +169,9 @@ void addExaTrkXTrackFinding(Context &ctx) { ActsExamples::TrackFindingAlgorithmExaTrkX, mex, "TrackFindingAlgorithmExaTrkX", inputSpacePoints, inputSimHits, inputParticles, inputClusters, inputMeasurementSimhitsMap, - outputProtoTracks, graphConstructor, edgeClassifiers, trackBuilder, - rScale, phiScale, zScale, cellCountScale, cellSumScale, clusterXScale, - clusterYScale, targetMinHits, targetMinPT); + outputProtoTracks, outputGraph, graphConstructor, edgeClassifiers, + trackBuilder, rScale, phiScale, zScale, cellCountScale, cellSumScale, + clusterXScale, clusterYScale, targetMinHits, targetMinPT); { auto cls = diff --git a/Examples/Python/src/Output.cpp b/Examples/Python/src/Output.cpp index cb4b6c4694f..ee28d2bf60e 100644 --- a/Examples/Python/src/Output.cpp +++ b/Examples/Python/src/Output.cpp @@ -14,6 +14,7 @@ #include "ActsExamples/Digitization/DigitizationConfig.hpp" #include "ActsExamples/Framework/ProcessCode.hpp" #include "ActsExamples/Io/Csv/CsvBFieldWriter.hpp" +#include "ActsExamples/Io/Csv/CsvExaTrkXGraphWriter.hpp" #include "ActsExamples/Io/Csv/CsvMeasurementWriter.hpp" #include "ActsExamples/Io/Csv/CsvParticleWriter.hpp" #include "ActsExamples/Io/Csv/CsvPlanarClusterWriter.hpp" @@ -379,7 +380,7 @@ void addOutput(Context& ctx) { inputParticles, inputMeasurementParticlesMap, filePath, fileMode, effPlotToolConfig, fakeRatePlotToolConfig, duplicationPlotToolConfig, - trackSummaryPlotToolConfig, duplicatedPredictor); + trackSummaryPlotToolConfig, duplicatedPredictor, truthMatchProbMin, doubleMatching); ACTS_PYTHON_DECLARE_WRITER( ActsExamples::RootNuclearInteractionParametersWriter, mex, @@ -411,5 +412,9 @@ void addOutput(Context& ctx) { register_csv_bfield_writer_binding(w); register_csv_bfield_writer_binding(w); } + + ACTS_PYTHON_DECLARE_WRITER(ActsExamples::CsvExaTrkXGraphWriter, mex, + "CsvExaTrkXGraphWriter", inputGraph, outputDir, + outputStem); } } // namespace Acts::Python diff --git a/Examples/Python/src/TruthTracking.cpp b/Examples/Python/src/TruthTracking.cpp index a49a2e972d7..dc360d3bf9a 100644 --- a/Examples/Python/src/TruthTracking.cpp +++ b/Examples/Python/src/TruthTracking.cpp @@ -17,6 +17,7 @@ #include "ActsExamples/TruthTracking/TruthSeedingAlgorithm.hpp" #include "ActsExamples/TruthTracking/TruthTrackFinder.hpp" #include "ActsExamples/TruthTracking/TruthVertexFinder.hpp" +#include "ActsExamples/Utilities/HitSelector.hpp" #include "ActsExamples/Utilities/Range.hpp" #include @@ -195,6 +196,9 @@ void addTruthTracking(Context& ctx) { ActsExamples::TruthSeedingAlgorithm, mex, "TruthSeedingAlgorithm", inputParticles, inputMeasurementParticlesMap, inputSpacePoints, outputParticles, outputSeeds, outputProtoTracks, deltaRMin, deltaRMax); + + ACTS_PYTHON_DECLARE_ALGORITHM(ActsExamples::HitSelector, mex, "HitSelector", + inputHits, outputHits, maxTime); } } // namespace Acts::Python diff --git a/Plugins/ExaTrkX/CMakeLists.txt b/Plugins/ExaTrkX/CMakeLists.txt index 7202deeb508..a0836829562 100644 --- a/Plugins/ExaTrkX/CMakeLists.txt +++ b/Plugins/ExaTrkX/CMakeLists.txt @@ -17,6 +17,7 @@ if(ACTS_EXATRKX_ENABLE_TORCH) src/TorchMetricLearning.cpp src/BoostTrackBuilding.cpp src/TorchTruthGraphMetricsHook.cpp + src/TorchGraphStoreHook.cpp ) endif() diff --git a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/ExaTrkXPipeline.hpp b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/ExaTrkXPipeline.hpp index e6810eb22ef..19d3a5db985 100644 --- a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/ExaTrkXPipeline.hpp +++ b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/ExaTrkXPipeline.hpp @@ -31,8 +31,9 @@ struct ExaTrkXTiming { class ExaTrkXHook { public: - virtual ~ExaTrkXHook() {} - virtual void operator()(const std::any &, const std::any &) const {}; + virtual ~ExaTrkXHook(){}; + virtual void operator()(const std::any &, const std::any &, + const std::any &) const {}; }; class ExaTrkXPipeline { diff --git a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchGraphStoreHook.hpp b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchGraphStoreHook.hpp new file mode 100644 index 00000000000..172b96fce20 --- /dev/null +++ b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchGraphStoreHook.hpp @@ -0,0 +1,34 @@ +// This file is part of the Acts project. +// +// Copyright (C) 2023 CERN for the benefit of the Acts project +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#pragma once + +#include "Acts/Plugins/ExaTrkX/ExaTrkXPipeline.hpp" +#include "Acts/Plugins/ExaTrkX/detail/CantorEdge.hpp" +#include "Acts/Utilities/Logger.hpp" + +namespace Acts { + +class TorchGraphStoreHook : public ExaTrkXHook { + public: + using Graph = std::pair, std::vector>; + + private: + std::unique_ptr m_storedGraph; + + public: + TorchGraphStoreHook(); + ~TorchGraphStoreHook() override {} + + void operator()(const std::any &, const std::any &edges, + const std::any &weights) const override; + + const Graph &storedGraph() const { return *m_storedGraph; } +}; + +} // namespace Acts diff --git a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchTruthGraphMetricsHook.hpp b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchTruthGraphMetricsHook.hpp index a13c9de984d..f971ae2992e 100644 --- a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchTruthGraphMetricsHook.hpp +++ b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchTruthGraphMetricsHook.hpp @@ -25,7 +25,8 @@ class TorchTruthGraphMetricsHook : public ExaTrkXHook { std::unique_ptr l); ~TorchTruthGraphMetricsHook() override {} - void operator()(const std::any &, const std::any &edges) const override; + void operator()(const std::any &, const std::any &edges, + const std::any &) const override; }; } // namespace Acts diff --git a/Plugins/ExaTrkX/src/ExaTrkXPipeline.cpp b/Plugins/ExaTrkX/src/ExaTrkXPipeline.cpp index 8c408413c16..3f8e88150f0 100644 --- a/Plugins/ExaTrkX/src/ExaTrkXPipeline.cpp +++ b/Plugins/ExaTrkX/src/ExaTrkXPipeline.cpp @@ -44,7 +44,7 @@ std::vector> ExaTrkXPipeline::run( timing->graphBuildingTime = t1 - t0; } - hook(nodes, edges); + hook(nodes, edges, {}); std::any edge_weights; timing->classifierTimes.clear(); @@ -63,7 +63,7 @@ std::vector> ExaTrkXPipeline::run( edges = std::move(newEdges); edge_weights = std::move(newWeights); - hook(nodes, edges); + hook(nodes, edges, edge_weights); } t0 = std::chrono::high_resolution_clock::now(); diff --git a/Plugins/ExaTrkX/src/TorchGraphStoreHook.cpp b/Plugins/ExaTrkX/src/TorchGraphStoreHook.cpp new file mode 100644 index 00000000000..0d5bd6b8c65 --- /dev/null +++ b/Plugins/ExaTrkX/src/TorchGraphStoreHook.cpp @@ -0,0 +1,33 @@ +// This file is part of the Acts project. +// +// Copyright (C) 2023 CERN for the benefit of the Acts project +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#include "Acts/Plugins/ExaTrkX/TorchGraphStoreHook.hpp" + +#include "Acts/Plugins/ExaTrkX/detail/TensorVectorConversion.hpp" + +#include + +Acts::TorchGraphStoreHook::TorchGraphStoreHook() { + m_storedGraph = std::make_unique(); +} + +void Acts::TorchGraphStoreHook::operator()(const std::any&, + const std::any& edges, + const std::any& weights) const { + if (not weights.has_value()) { + return; + } + + m_storedGraph->first = detail::tensor2DToVector( + std::any_cast(edges).t()); + + auto cpuWeights = std::any_cast(weights).to(torch::kCPU); + m_storedGraph->second = + std::vector(cpuWeights.data_ptr(), + cpuWeights.data_ptr() + cpuWeights.numel()); +} diff --git a/Plugins/ExaTrkX/src/TorchTruthGraphMetricsHook.cpp b/Plugins/ExaTrkX/src/TorchTruthGraphMetricsHook.cpp index 50712797c39..851142274b9 100644 --- a/Plugins/ExaTrkX/src/TorchTruthGraphMetricsHook.cpp +++ b/Plugins/ExaTrkX/src/TorchTruthGraphMetricsHook.cpp @@ -46,7 +46,8 @@ Acts::TorchTruthGraphMetricsHook::TorchTruthGraphMetricsHook( } void Acts::TorchTruthGraphMetricsHook::operator()(const std::any&, - const std::any& edges) const { + const std::any& edges, + const std::any&) const { // We need to transpose the edges here for the right memory layout const auto edgeIndex = Acts::detail::tensor2DToVector( std::any_cast(edges).t()); From 833c29b07f0c37353b8cad591d2433f6faa61c12 Mon Sep 17 00:00:00 2001 From: Benjamin Huth Date: Mon, 27 Nov 2023 18:12:23 +0100 Subject: [PATCH 2/3] update --- .../TrackFindingAlgorithmExaTrkX.hpp | 6 -- .../src/TrackFindingAlgorithmExaTrkX.cpp | 35 +----------- Examples/Algorithms/Utilities/CMakeLists.txt | 1 - Examples/Io/Csv/CMakeLists.txt | 1 - .../Io/Csv/CsvExaTrkXGraphWriter.hpp | 57 ------------------- Examples/Io/Csv/src/CsvExaTrkXGraphWriter.cpp | 56 ------------------ Examples/Python/src/ExaTrkXTrackFinding.cpp | 6 +- Examples/Python/src/Output.cpp | 7 +-- Plugins/ExaTrkX/CMakeLists.txt | 1 - .../Acts/Plugins/ExaTrkX/ExaTrkXPipeline.hpp | 5 +- .../Plugins/ExaTrkX/TorchGraphStoreHook.hpp | 34 ----------- .../ExaTrkX/TorchTruthGraphMetricsHook.hpp | 3 +- Plugins/ExaTrkX/src/ExaTrkXPipeline.cpp | 4 +- Plugins/ExaTrkX/src/TorchGraphStoreHook.cpp | 33 ----------- .../src/TorchTruthGraphMetricsHook.cpp | 3 +- 15 files changed, 13 insertions(+), 239 deletions(-) delete mode 100644 Examples/Io/Csv/include/ActsExamples/Io/Csv/CsvExaTrkXGraphWriter.hpp delete mode 100644 Examples/Io/Csv/src/CsvExaTrkXGraphWriter.cpp delete mode 100644 Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchGraphStoreHook.hpp delete mode 100644 Plugins/ExaTrkX/src/TorchGraphStoreHook.cpp diff --git a/Examples/Algorithms/TrackFindingExaTrkX/include/ActsExamples/TrackFindingExaTrkX/TrackFindingAlgorithmExaTrkX.hpp b/Examples/Algorithms/TrackFindingExaTrkX/include/ActsExamples/TrackFindingExaTrkX/TrackFindingAlgorithmExaTrkX.hpp index 6e01c7c35dd..d6366d25b43 100644 --- a/Examples/Algorithms/TrackFindingExaTrkX/include/ActsExamples/TrackFindingExaTrkX/TrackFindingAlgorithmExaTrkX.hpp +++ b/Examples/Algorithms/TrackFindingExaTrkX/include/ActsExamples/TrackFindingExaTrkX/TrackFindingAlgorithmExaTrkX.hpp @@ -11,7 +11,6 @@ #include "Acts/Definitions/Units.hpp" #include "Acts/Plugins/ExaTrkX/ExaTrkXPipeline.hpp" #include "Acts/Plugins/ExaTrkX/Stages.hpp" -#include "Acts/Plugins/ExaTrkX/TorchGraphStoreHook.hpp" #include "ActsExamples/EventData/Cluster.hpp" #include "ActsExamples/EventData/ProtoTrack.hpp" #include "ActsExamples/EventData/SimHit.hpp" @@ -53,9 +52,6 @@ class TrackFindingAlgorithmExaTrkX final : public IAlgorithm { /// Output protoTracks collection. std::string outputProtoTracks; - /// Output graph (optional) - std::string outputGraph; - std::shared_ptr graphConstructor; std::vector> edgeClassifiers; @@ -118,8 +114,6 @@ class TrackFindingAlgorithmExaTrkX final : public IAlgorithm { WriteDataHandle m_outputProtoTracks{this, "OutputProtoTracks"}; - WriteDataHandle m_outputGraph{ - this, "OutputGraph"}; // for truth graph ReadDataHandle m_inputSimHits{this, "InputSimHits"}; diff --git a/Examples/Algorithms/TrackFindingExaTrkX/src/TrackFindingAlgorithmExaTrkX.cpp b/Examples/Algorithms/TrackFindingExaTrkX/src/TrackFindingAlgorithmExaTrkX.cpp index 819d55e969d..4215828d62c 100644 --- a/Examples/Algorithms/TrackFindingExaTrkX/src/TrackFindingAlgorithmExaTrkX.cpp +++ b/Examples/Algorithms/TrackFindingExaTrkX/src/TrackFindingAlgorithmExaTrkX.cpp @@ -9,7 +9,6 @@ #include "ActsExamples/TrackFindingExaTrkX/TrackFindingAlgorithmExaTrkX.hpp" #include "Acts/Definitions/Units.hpp" -#include "Acts/Plugins/ExaTrkX/TorchGraphStoreHook.hpp" #include "Acts/Plugins/ExaTrkX/TorchTruthGraphMetricsHook.hpp" #include "Acts/Utilities/Zip.hpp" #include "ActsExamples/EventData/Index.hpp" @@ -32,7 +31,6 @@ class ExamplesEdmHook : public Acts::ExaTrkXHook { std::unique_ptr m_logger; std::unique_ptr m_truthGraphHook; std::unique_ptr m_targetGraphHook; - std::unique_ptr m_graphStoreHook; const Acts::Logger& logger() const { return *m_logger; } @@ -100,22 +98,17 @@ class ExamplesEdmHook : public Acts::ExaTrkXHook { truthGraph, logger.clone()); m_targetGraphHook = std::make_unique( targetGraph, logger.clone()); - m_graphStoreHook = std::make_unique(); } ~ExamplesEdmHook() {} - auto storedGraph() const { return m_graphStoreHook->storedGraph(); } - - void operator()(const std::any& nodes, const std::any& edges, - const std::any& weights) const override { + void operator()(const std::any& nodes, const std::any& edges) const override { ACTS_INFO("Metrics for total graph:"); - (*m_truthGraphHook)(nodes, edges, weights); + (*m_truthGraphHook)(nodes, edges); ACTS_INFO("Metrics for target graph (pT > " << m_targetPT / Acts::UnitConstants::GeV << " GeV, nHits >= " << m_targetSize << "):"); - (*m_targetGraphHook)(nodes, edges, weights); - (*m_graphStoreHook)(nodes, edges, weights); + (*m_targetGraphHook)(nodes, edges); } }; @@ -160,8 +153,6 @@ ActsExamples::TrackFindingAlgorithmExaTrkX::TrackFindingAlgorithmExaTrkX( m_inputParticles.maybeInitialize(m_cfg.inputParticles); m_inputMeasurementMap.maybeInitialize(m_cfg.inputMeasurementSimhitsMap); - m_outputGraph.maybeInitialize(m_cfg.outputGraph); - // reserve space for timing m_timing.classifierTimes.resize( m_cfg.edgeClassifiers.size(), @@ -276,35 +267,15 @@ ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmExaTrkX::execute( // Make the prototracks std::vector protoTracks; protoTracks.reserve(trackCandidates.size()); - - int nShortTracks = 0; - for (auto& x : trackCandidates) { - if (x.size() < 3) { - nShortTracks++; - continue; - } - ProtoTrack onetrack; - onetrack.reserve(x.size()); - std::copy(x.begin(), x.end(), std::back_inserter(onetrack)); protoTracks.push_back(std::move(onetrack)); } - ACTS_INFO("Removed " << nShortTracks << " with less then 3 hits"); ACTS_INFO("Created " << protoTracks.size() << " proto tracks"); m_outputProtoTracks(ctx, std::move(protoTracks)); - if (auto dhook = dynamic_cast(&*hook); - dhook && m_outputGraph.isInitialized()) { - auto graph = dhook->storedGraph(); - std::transform( - graph.first.begin(), graph.first.end(), graph.first.begin(), - [&](const auto& a) -> int64_t { return spacepointIDs.at(a); }); - m_outputGraph(ctx, std::move(graph)); - } - return ActsExamples::ProcessCode::SUCCESS; } diff --git a/Examples/Algorithms/Utilities/CMakeLists.txt b/Examples/Algorithms/Utilities/CMakeLists.txt index 855f3bd49b7..c99737e1305 100644 --- a/Examples/Algorithms/Utilities/CMakeLists.txt +++ b/Examples/Algorithms/Utilities/CMakeLists.txt @@ -5,7 +5,6 @@ add_library( src/TrajectoriesToPrototracks.cpp src/TrackSelectorAlgorithm.cpp src/TracksToTrajectories.cpp - src/HitSelector.cpp src/TracksToParameters.cpp) target_include_directories( ActsExamplesUtilities diff --git a/Examples/Io/Csv/CMakeLists.txt b/Examples/Io/Csv/CMakeLists.txt index 8d07c38a8cf..d8faabbcc36 100644 --- a/Examples/Io/Csv/CMakeLists.txt +++ b/Examples/Io/Csv/CMakeLists.txt @@ -18,7 +18,6 @@ add_library( src/CsvTrackWriter.cpp src/CsvProtoTrackWriter.cpp src/CsvSpacePointWriter.cpp - src/CsvExaTrkXGraphWriter.cpp src/CsvBFieldWriter.cpp) target_include_directories( ActsExamplesIoCsv diff --git a/Examples/Io/Csv/include/ActsExamples/Io/Csv/CsvExaTrkXGraphWriter.hpp b/Examples/Io/Csv/include/ActsExamples/Io/Csv/CsvExaTrkXGraphWriter.hpp deleted file mode 100644 index 9a420f90b23..00000000000 --- a/Examples/Io/Csv/include/ActsExamples/Io/Csv/CsvExaTrkXGraphWriter.hpp +++ /dev/null @@ -1,57 +0,0 @@ -// This file is part of the Acts project. -// -// Copyright (C) 2020 CERN for the benefit of the Acts project -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -#pragma once - -#include "Acts/Utilities/Logger.hpp" -#include "ActsExamples/Framework/ProcessCode.hpp" -#include "ActsExamples/Framework/WriterT.hpp" -#include "ActsExamples/Utilities/Paths.hpp" - -#include -#include -#include - -namespace ActsExamples { -struct AlgorithmContext; - -class CsvExaTrkXGraphWriter final - : public WriterT, std::vector>> { - public: - struct Config { - /// Which simulated (truth) hits collection to use. - std::string inputGraph; - /// Where to place output files - std::string outputDir; - /// Output filename stem. - std::string outputStem = "exatrkx-graph"; - }; - - /// Construct the cluster writer. - /// - /// @param config is the configuration object - /// @param level is the logging level - CsvExaTrkXGraphWriter(const Config& config, Acts::Logging::Level level); - - /// Readonly access to the config - const Config& config() const { return m_cfg; } - - protected: - /// Type-specific write implementation. - /// - /// @param[in] ctx is the algorithm context - /// @param[in] simHits are the simhits to be written - ProcessCode writeT(const AlgorithmContext& ctx, - const std::pair, std::vector>& - graph) override; - - private: - Config m_cfg; -}; - -} // namespace ActsExamples diff --git a/Examples/Io/Csv/src/CsvExaTrkXGraphWriter.cpp b/Examples/Io/Csv/src/CsvExaTrkXGraphWriter.cpp deleted file mode 100644 index 9ee9998b524..00000000000 --- a/Examples/Io/Csv/src/CsvExaTrkXGraphWriter.cpp +++ /dev/null @@ -1,56 +0,0 @@ -// This file is part of the Acts project. -// -// Copyright (C) 2020 CERN for the benefit of the Acts project -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -#include "ActsExamples/Io/Csv/CsvExaTrkXGraphWriter.hpp" - -#include "Acts/Definitions/Algebra.hpp" -#include "Acts/Definitions/Common.hpp" -#include "Acts/Definitions/Units.hpp" -#include "ActsExamples/Framework/AlgorithmContext.hpp" -#include "ActsExamples/Utilities/Paths.hpp" -#include "ActsFatras/EventData/Barcode.hpp" - -#include -#include - -#include -#include - -struct GraphData { - int64_t edge0; - int64_t edge1; - float weight; - DFE_NAMEDTUPLE(GraphData, edge0, edge1, weight); -}; - -ActsExamples::CsvExaTrkXGraphWriter::CsvExaTrkXGraphWriter( - const ActsExamples::CsvExaTrkXGraphWriter::Config& config, - Acts::Logging::Level level) - : WriterT(config.inputGraph, "CsvExaTrkXGraphWriter", level), - m_cfg(config) {} - -ActsExamples::ProcessCode ActsExamples::CsvExaTrkXGraphWriter::writeT( - const ActsExamples::AlgorithmContext& ctx, - const std::pair, std::vector>& graph) { - std::string path = perEventFilepath( - m_cfg.outputDir, m_cfg.outputStem + ".csv", ctx.eventNumber); - - dfe::NamedTupleCsvWriter writer(path); - - const auto& [edges, weights] = graph; - - for (auto i = 0ul; i < weights.size(); ++i) { - GraphData edge; - edge.edge0 = edges[2 * i]; - edge.edge1 = edges[2 * i + 1]; - edge.weight = weights[i]; - writer.append(edge); - } - - return ActsExamples::ProcessCode::SUCCESS; -} diff --git a/Examples/Python/src/ExaTrkXTrackFinding.cpp b/Examples/Python/src/ExaTrkXTrackFinding.cpp index 3131754e802..ad27045da03 100644 --- a/Examples/Python/src/ExaTrkXTrackFinding.cpp +++ b/Examples/Python/src/ExaTrkXTrackFinding.cpp @@ -169,9 +169,9 @@ void addExaTrkXTrackFinding(Context &ctx) { ActsExamples::TrackFindingAlgorithmExaTrkX, mex, "TrackFindingAlgorithmExaTrkX", inputSpacePoints, inputSimHits, inputParticles, inputClusters, inputMeasurementSimhitsMap, - outputProtoTracks, outputGraph, graphConstructor, edgeClassifiers, - trackBuilder, rScale, phiScale, zScale, cellCountScale, cellSumScale, - clusterXScale, clusterYScale, targetMinHits, targetMinPT); + outputProtoTracks, graphConstructor, edgeClassifiers, trackBuilder, + rScale, phiScale, zScale, cellCountScale, cellSumScale, clusterXScale, + clusterYScale, targetMinHits, targetMinPT); { auto cls = diff --git a/Examples/Python/src/Output.cpp b/Examples/Python/src/Output.cpp index ee28d2bf60e..cb4b6c4694f 100644 --- a/Examples/Python/src/Output.cpp +++ b/Examples/Python/src/Output.cpp @@ -14,7 +14,6 @@ #include "ActsExamples/Digitization/DigitizationConfig.hpp" #include "ActsExamples/Framework/ProcessCode.hpp" #include "ActsExamples/Io/Csv/CsvBFieldWriter.hpp" -#include "ActsExamples/Io/Csv/CsvExaTrkXGraphWriter.hpp" #include "ActsExamples/Io/Csv/CsvMeasurementWriter.hpp" #include "ActsExamples/Io/Csv/CsvParticleWriter.hpp" #include "ActsExamples/Io/Csv/CsvPlanarClusterWriter.hpp" @@ -380,7 +379,7 @@ void addOutput(Context& ctx) { inputParticles, inputMeasurementParticlesMap, filePath, fileMode, effPlotToolConfig, fakeRatePlotToolConfig, duplicationPlotToolConfig, - trackSummaryPlotToolConfig, duplicatedPredictor, truthMatchProbMin, doubleMatching); + trackSummaryPlotToolConfig, duplicatedPredictor); ACTS_PYTHON_DECLARE_WRITER( ActsExamples::RootNuclearInteractionParametersWriter, mex, @@ -412,9 +411,5 @@ void addOutput(Context& ctx) { register_csv_bfield_writer_binding(w); register_csv_bfield_writer_binding(w); } - - ACTS_PYTHON_DECLARE_WRITER(ActsExamples::CsvExaTrkXGraphWriter, mex, - "CsvExaTrkXGraphWriter", inputGraph, outputDir, - outputStem); } } // namespace Acts::Python diff --git a/Plugins/ExaTrkX/CMakeLists.txt b/Plugins/ExaTrkX/CMakeLists.txt index a0836829562..7202deeb508 100644 --- a/Plugins/ExaTrkX/CMakeLists.txt +++ b/Plugins/ExaTrkX/CMakeLists.txt @@ -17,7 +17,6 @@ if(ACTS_EXATRKX_ENABLE_TORCH) src/TorchMetricLearning.cpp src/BoostTrackBuilding.cpp src/TorchTruthGraphMetricsHook.cpp - src/TorchGraphStoreHook.cpp ) endif() diff --git a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/ExaTrkXPipeline.hpp b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/ExaTrkXPipeline.hpp index 19d3a5db985..e6810eb22ef 100644 --- a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/ExaTrkXPipeline.hpp +++ b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/ExaTrkXPipeline.hpp @@ -31,9 +31,8 @@ struct ExaTrkXTiming { class ExaTrkXHook { public: - virtual ~ExaTrkXHook(){}; - virtual void operator()(const std::any &, const std::any &, - const std::any &) const {}; + virtual ~ExaTrkXHook() {} + virtual void operator()(const std::any &, const std::any &) const {}; }; class ExaTrkXPipeline { diff --git a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchGraphStoreHook.hpp b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchGraphStoreHook.hpp deleted file mode 100644 index 172b96fce20..00000000000 --- a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchGraphStoreHook.hpp +++ /dev/null @@ -1,34 +0,0 @@ -// This file is part of the Acts project. -// -// Copyright (C) 2023 CERN for the benefit of the Acts project -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -#pragma once - -#include "Acts/Plugins/ExaTrkX/ExaTrkXPipeline.hpp" -#include "Acts/Plugins/ExaTrkX/detail/CantorEdge.hpp" -#include "Acts/Utilities/Logger.hpp" - -namespace Acts { - -class TorchGraphStoreHook : public ExaTrkXHook { - public: - using Graph = std::pair, std::vector>; - - private: - std::unique_ptr m_storedGraph; - - public: - TorchGraphStoreHook(); - ~TorchGraphStoreHook() override {} - - void operator()(const std::any &, const std::any &edges, - const std::any &weights) const override; - - const Graph &storedGraph() const { return *m_storedGraph; } -}; - -} // namespace Acts diff --git a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchTruthGraphMetricsHook.hpp b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchTruthGraphMetricsHook.hpp index f971ae2992e..a13c9de984d 100644 --- a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchTruthGraphMetricsHook.hpp +++ b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchTruthGraphMetricsHook.hpp @@ -25,8 +25,7 @@ class TorchTruthGraphMetricsHook : public ExaTrkXHook { std::unique_ptr l); ~TorchTruthGraphMetricsHook() override {} - void operator()(const std::any &, const std::any &edges, - const std::any &) const override; + void operator()(const std::any &, const std::any &edges) const override; }; } // namespace Acts diff --git a/Plugins/ExaTrkX/src/ExaTrkXPipeline.cpp b/Plugins/ExaTrkX/src/ExaTrkXPipeline.cpp index 3f8e88150f0..8c408413c16 100644 --- a/Plugins/ExaTrkX/src/ExaTrkXPipeline.cpp +++ b/Plugins/ExaTrkX/src/ExaTrkXPipeline.cpp @@ -44,7 +44,7 @@ std::vector> ExaTrkXPipeline::run( timing->graphBuildingTime = t1 - t0; } - hook(nodes, edges, {}); + hook(nodes, edges); std::any edge_weights; timing->classifierTimes.clear(); @@ -63,7 +63,7 @@ std::vector> ExaTrkXPipeline::run( edges = std::move(newEdges); edge_weights = std::move(newWeights); - hook(nodes, edges, edge_weights); + hook(nodes, edges); } t0 = std::chrono::high_resolution_clock::now(); diff --git a/Plugins/ExaTrkX/src/TorchGraphStoreHook.cpp b/Plugins/ExaTrkX/src/TorchGraphStoreHook.cpp deleted file mode 100644 index 0d5bd6b8c65..00000000000 --- a/Plugins/ExaTrkX/src/TorchGraphStoreHook.cpp +++ /dev/null @@ -1,33 +0,0 @@ -// This file is part of the Acts project. -// -// Copyright (C) 2023 CERN for the benefit of the Acts project -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -#include "Acts/Plugins/ExaTrkX/TorchGraphStoreHook.hpp" - -#include "Acts/Plugins/ExaTrkX/detail/TensorVectorConversion.hpp" - -#include - -Acts::TorchGraphStoreHook::TorchGraphStoreHook() { - m_storedGraph = std::make_unique(); -} - -void Acts::TorchGraphStoreHook::operator()(const std::any&, - const std::any& edges, - const std::any& weights) const { - if (not weights.has_value()) { - return; - } - - m_storedGraph->first = detail::tensor2DToVector( - std::any_cast(edges).t()); - - auto cpuWeights = std::any_cast(weights).to(torch::kCPU); - m_storedGraph->second = - std::vector(cpuWeights.data_ptr(), - cpuWeights.data_ptr() + cpuWeights.numel()); -} diff --git a/Plugins/ExaTrkX/src/TorchTruthGraphMetricsHook.cpp b/Plugins/ExaTrkX/src/TorchTruthGraphMetricsHook.cpp index 851142274b9..50712797c39 100644 --- a/Plugins/ExaTrkX/src/TorchTruthGraphMetricsHook.cpp +++ b/Plugins/ExaTrkX/src/TorchTruthGraphMetricsHook.cpp @@ -46,8 +46,7 @@ Acts::TorchTruthGraphMetricsHook::TorchTruthGraphMetricsHook( } void Acts::TorchTruthGraphMetricsHook::operator()(const std::any&, - const std::any& edges, - const std::any&) const { + const std::any& edges) const { // We need to transpose the edges here for the right memory layout const auto edgeIndex = Acts::detail::tensor2DToVector( std::any_cast(edges).t()); From 947fd039abfdc806943f513576cce73e30b9aab5 Mon Sep 17 00:00:00 2001 From: Benjamin Huth Date: Mon, 27 Nov 2023 18:53:19 +0100 Subject: [PATCH 3/3] fix cmakelists --- Examples/Algorithms/Utilities/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/Examples/Algorithms/Utilities/CMakeLists.txt b/Examples/Algorithms/Utilities/CMakeLists.txt index c99737e1305..855f3bd49b7 100644 --- a/Examples/Algorithms/Utilities/CMakeLists.txt +++ b/Examples/Algorithms/Utilities/CMakeLists.txt @@ -5,6 +5,7 @@ add_library( src/TrajectoriesToPrototracks.cpp src/TrackSelectorAlgorithm.cpp src/TracksToTrajectories.cpp + src/HitSelector.cpp src/TracksToParameters.cpp) target_include_directories( ActsExamplesUtilities