diff --git a/Examples/Algorithms/TrackFindingExaTrkX/include/ActsExamples/TrackFindingExaTrkX/TrackFindingAlgorithmExaTrkX.hpp b/Examples/Algorithms/TrackFindingExaTrkX/include/ActsExamples/TrackFindingExaTrkX/TrackFindingAlgorithmExaTrkX.hpp index d6366d25b43..c6eff23d627 100644 --- a/Examples/Algorithms/TrackFindingExaTrkX/include/ActsExamples/TrackFindingExaTrkX/TrackFindingAlgorithmExaTrkX.hpp +++ b/Examples/Algorithms/TrackFindingExaTrkX/include/ActsExamples/TrackFindingExaTrkX/TrackFindingAlgorithmExaTrkX.hpp @@ -11,6 +11,7 @@ #include "Acts/Definitions/Units.hpp" #include "Acts/Plugins/ExaTrkX/ExaTrkXPipeline.hpp" #include "Acts/Plugins/ExaTrkX/Stages.hpp" +#include "Acts/Plugins/ExaTrkX/TorchGraphStoreHook.hpp" #include "ActsExamples/EventData/Cluster.hpp" #include "ActsExamples/EventData/ProtoTrack.hpp" #include "ActsExamples/EventData/SimHit.hpp" @@ -52,6 +53,9 @@ class TrackFindingAlgorithmExaTrkX final : public IAlgorithm { /// Output protoTracks collection. std::string outputProtoTracks; + /// Output graph (optional) + std::string outputGraph; + std::shared_ptr graphConstructor; std::vector> edgeClassifiers; @@ -67,6 +71,9 @@ class TrackFindingAlgorithmExaTrkX final : public IAlgorithm { float clusterXScale = 1.f; float clusterYScale = 1.f; + /// Remove track candidates with 2 or less hits + bool filterShortTracks = false; + /// Target graph properties std::size_t targetMinHits = 3; double targetMinPT = 500 * Acts::UnitConstants::MeV; @@ -114,6 +121,8 @@ class TrackFindingAlgorithmExaTrkX final : public IAlgorithm { WriteDataHandle m_outputProtoTracks{this, "OutputProtoTracks"}; + WriteDataHandle m_outputGraph{ + this, "OutputGraph"}; // for truth graph ReadDataHandle m_inputSimHits{this, "InputSimHits"}; diff --git a/Examples/Algorithms/TrackFindingExaTrkX/src/TrackFindingAlgorithmExaTrkX.cpp b/Examples/Algorithms/TrackFindingExaTrkX/src/TrackFindingAlgorithmExaTrkX.cpp index 4215828d62c..f87d9714b6f 100644 --- a/Examples/Algorithms/TrackFindingExaTrkX/src/TrackFindingAlgorithmExaTrkX.cpp +++ b/Examples/Algorithms/TrackFindingExaTrkX/src/TrackFindingAlgorithmExaTrkX.cpp @@ -9,6 +9,7 @@ #include "ActsExamples/TrackFindingExaTrkX/TrackFindingAlgorithmExaTrkX.hpp" #include "Acts/Definitions/Units.hpp" +#include "Acts/Plugins/ExaTrkX/TorchGraphStoreHook.hpp" #include "Acts/Plugins/ExaTrkX/TorchTruthGraphMetricsHook.hpp" #include "Acts/Utilities/Zip.hpp" #include "ActsExamples/EventData/Index.hpp" @@ -31,6 +32,7 @@ class ExamplesEdmHook : public Acts::ExaTrkXHook { std::unique_ptr m_logger; std::unique_ptr m_truthGraphHook; std::unique_ptr m_targetGraphHook; + std::unique_ptr m_graphStoreHook; const Acts::Logger& logger() const { return *m_logger; } @@ -98,17 +100,22 @@ class ExamplesEdmHook : public Acts::ExaTrkXHook { truthGraph, logger.clone()); m_targetGraphHook = std::make_unique( targetGraph, logger.clone()); + m_graphStoreHook = std::make_unique(); } ~ExamplesEdmHook() {} - void operator()(const std::any& nodes, const std::any& edges) const override { + auto storedGraph() const { return m_graphStoreHook->storedGraph(); } + + void operator()(const std::any& nodes, const std::any& edges, + const std::any& weights) const override { ACTS_INFO("Metrics for total graph:"); - (*m_truthGraphHook)(nodes, edges); + (*m_truthGraphHook)(nodes, edges, weights); ACTS_INFO("Metrics for target graph (pT > " << m_targetPT / Acts::UnitConstants::GeV << " GeV, nHits >= " << m_targetSize << "):"); - (*m_targetGraphHook)(nodes, edges); + (*m_targetGraphHook)(nodes, edges, weights); + (*m_graphStoreHook)(nodes, edges, weights); } }; @@ -153,6 +160,8 @@ ActsExamples::TrackFindingAlgorithmExaTrkX::TrackFindingAlgorithmExaTrkX( m_inputParticles.maybeInitialize(m_cfg.inputParticles); m_inputMeasurementMap.maybeInitialize(m_cfg.inputMeasurementSimhitsMap); + m_outputGraph.maybeInitialize(m_cfg.outputGraph); + // reserve space for timing m_timing.classifierTimes.resize( m_cfg.edgeClassifiers.size(), @@ -267,15 +276,35 @@ ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmExaTrkX::execute( // Make the prototracks std::vector protoTracks; protoTracks.reserve(trackCandidates.size()); + + int nShortTracks = 0; + for (auto& x : trackCandidates) { + if (m_cfg.filterShortTracks && x.size() < 3) { + nShortTracks++; + continue; + } + ProtoTrack onetrack; + onetrack.reserve(x.size()); + std::copy(x.begin(), x.end(), std::back_inserter(onetrack)); protoTracks.push_back(std::move(onetrack)); } + ACTS_INFO("Removed " << nShortTracks << " with less then 3 hits"); ACTS_INFO("Created " << protoTracks.size() << " proto tracks"); m_outputProtoTracks(ctx, std::move(protoTracks)); + if (auto dhook = dynamic_cast(&*hook); + dhook && m_outputGraph.isInitialized()) { + auto graph = dhook->storedGraph(); + std::transform( + graph.first.begin(), graph.first.end(), graph.first.begin(), + [&](const auto& a) -> int64_t { return spacepointIDs.at(a); }); + m_outputGraph(ctx, std::move(graph)); + } + return ActsExamples::ProcessCode::SUCCESS; } diff --git a/Examples/Io/Csv/CMakeLists.txt b/Examples/Io/Csv/CMakeLists.txt index d8faabbcc36..8d07c38a8cf 100644 --- a/Examples/Io/Csv/CMakeLists.txt +++ b/Examples/Io/Csv/CMakeLists.txt @@ -18,6 +18,7 @@ add_library( src/CsvTrackWriter.cpp src/CsvProtoTrackWriter.cpp src/CsvSpacePointWriter.cpp + src/CsvExaTrkXGraphWriter.cpp src/CsvBFieldWriter.cpp) target_include_directories( ActsExamplesIoCsv diff --git a/Examples/Io/Csv/include/ActsExamples/Io/Csv/CsvExaTrkXGraphWriter.hpp b/Examples/Io/Csv/include/ActsExamples/Io/Csv/CsvExaTrkXGraphWriter.hpp new file mode 100644 index 00000000000..9a420f90b23 --- /dev/null +++ b/Examples/Io/Csv/include/ActsExamples/Io/Csv/CsvExaTrkXGraphWriter.hpp @@ -0,0 +1,57 @@ +// This file is part of the Acts project. +// +// Copyright (C) 2020 CERN for the benefit of the Acts project +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#pragma once + +#include "Acts/Utilities/Logger.hpp" +#include "ActsExamples/Framework/ProcessCode.hpp" +#include "ActsExamples/Framework/WriterT.hpp" +#include "ActsExamples/Utilities/Paths.hpp" + +#include +#include +#include + +namespace ActsExamples { +struct AlgorithmContext; + +class CsvExaTrkXGraphWriter final + : public WriterT, std::vector>> { + public: + struct Config { + /// Which simulated (truth) hits collection to use. + std::string inputGraph; + /// Where to place output files + std::string outputDir; + /// Output filename stem. + std::string outputStem = "exatrkx-graph"; + }; + + /// Construct the cluster writer. + /// + /// @param config is the configuration object + /// @param level is the logging level + CsvExaTrkXGraphWriter(const Config& config, Acts::Logging::Level level); + + /// Readonly access to the config + const Config& config() const { return m_cfg; } + + protected: + /// Type-specific write implementation. + /// + /// @param[in] ctx is the algorithm context + /// @param[in] simHits are the simhits to be written + ProcessCode writeT(const AlgorithmContext& ctx, + const std::pair, std::vector>& + graph) override; + + private: + Config m_cfg; +}; + +} // namespace ActsExamples diff --git a/Examples/Io/Csv/src/CsvExaTrkXGraphWriter.cpp b/Examples/Io/Csv/src/CsvExaTrkXGraphWriter.cpp new file mode 100644 index 00000000000..f5553a20fe1 --- /dev/null +++ b/Examples/Io/Csv/src/CsvExaTrkXGraphWriter.cpp @@ -0,0 +1,56 @@ +// This file is part of the Acts project. +// +// Copyright (C) 2020 CERN for the benefit of the Acts project +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#include "ActsExamples/Io/Csv/CsvExaTrkXGraphWriter.hpp" + +#include "Acts/Definitions/Algebra.hpp" +#include "Acts/Definitions/Common.hpp" +#include "Acts/Definitions/Units.hpp" +#include "ActsExamples/Framework/AlgorithmContext.hpp" +#include "ActsExamples/Utilities/Paths.hpp" +#include "ActsFatras/EventData/Barcode.hpp" + +#include +#include + +#include +#include + +struct GraphData { + int64_t edge0; + int64_t edge1; + float weight; + DFE_NAMEDTUPLE(GraphData, edge0, edge1, weight); +}; + +ActsExamples::CsvExaTrkXGraphWriter::CsvExaTrkXGraphWriter( + const ActsExamples::CsvExaTrkXGraphWriter::Config& config, + Acts::Logging::Level level) + : WriterT(config.inputGraph, "CsvExaTrkXGraphWriter", level), + m_cfg(config) {} + +ActsExamples::ProcessCode ActsExamples::CsvExaTrkXGraphWriter::writeT( + const ActsExamples::AlgorithmContext& ctx, + const std::pair, std::vector>& graph) { + std::string path = perEventFilepath( + m_cfg.outputDir, m_cfg.outputStem + ".csv", ctx.eventNumber); + + dfe::NamedTupleCsvWriter writer(path); + + const auto& [edges, weights] = graph; + + for (auto i = 0ul; i < weights.size(); ++i) { + GraphData edge{}; + edge.edge0 = edges[2 * i]; + edge.edge1 = edges[2 * i + 1]; + edge.weight = weights[i]; + writer.append(edge); + } + + return ActsExamples::ProcessCode::SUCCESS; +} diff --git a/Examples/Python/src/ExaTrkXTrackFinding.cpp b/Examples/Python/src/ExaTrkXTrackFinding.cpp index ad27045da03..c167ce54f67 100644 --- a/Examples/Python/src/ExaTrkXTrackFinding.cpp +++ b/Examples/Python/src/ExaTrkXTrackFinding.cpp @@ -169,9 +169,10 @@ void addExaTrkXTrackFinding(Context &ctx) { ActsExamples::TrackFindingAlgorithmExaTrkX, mex, "TrackFindingAlgorithmExaTrkX", inputSpacePoints, inputSimHits, inputParticles, inputClusters, inputMeasurementSimhitsMap, - outputProtoTracks, graphConstructor, edgeClassifiers, trackBuilder, - rScale, phiScale, zScale, cellCountScale, cellSumScale, clusterXScale, - clusterYScale, targetMinHits, targetMinPT); + outputProtoTracks, outputGraph, graphConstructor, edgeClassifiers, + trackBuilder, rScale, phiScale, zScale, cellCountScale, cellSumScale, + clusterXScale, clusterYScale, filterShortTracks, targetMinHits, + targetMinPT); { auto cls = diff --git a/Examples/Python/src/Output.cpp b/Examples/Python/src/Output.cpp index 560407c8638..ec06d8fc481 100644 --- a/Examples/Python/src/Output.cpp +++ b/Examples/Python/src/Output.cpp @@ -14,6 +14,7 @@ #include "ActsExamples/Digitization/DigitizationConfig.hpp" #include "ActsExamples/Framework/ProcessCode.hpp" #include "ActsExamples/Io/Csv/CsvBFieldWriter.hpp" +#include "ActsExamples/Io/Csv/CsvExaTrkXGraphWriter.hpp" #include "ActsExamples/Io/Csv/CsvMeasurementWriter.hpp" #include "ActsExamples/Io/Csv/CsvParticleWriter.hpp" #include "ActsExamples/Io/Csv/CsvPlanarClusterWriter.hpp" @@ -412,5 +413,9 @@ void addOutput(Context& ctx) { register_csv_bfield_writer_binding(w); register_csv_bfield_writer_binding(w); } + + ACTS_PYTHON_DECLARE_WRITER(ActsExamples::CsvExaTrkXGraphWriter, mex, + "CsvExaTrkXGraphWriter", inputGraph, outputDir, + outputStem); } } // namespace Acts::Python diff --git a/Plugins/ExaTrkX/CMakeLists.txt b/Plugins/ExaTrkX/CMakeLists.txt index 7202deeb508..a0836829562 100644 --- a/Plugins/ExaTrkX/CMakeLists.txt +++ b/Plugins/ExaTrkX/CMakeLists.txt @@ -17,6 +17,7 @@ if(ACTS_EXATRKX_ENABLE_TORCH) src/TorchMetricLearning.cpp src/BoostTrackBuilding.cpp src/TorchTruthGraphMetricsHook.cpp + src/TorchGraphStoreHook.cpp ) endif() diff --git a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/ExaTrkXPipeline.hpp b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/ExaTrkXPipeline.hpp index e6810eb22ef..9a830bdf073 100644 --- a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/ExaTrkXPipeline.hpp +++ b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/ExaTrkXPipeline.hpp @@ -32,7 +32,8 @@ struct ExaTrkXTiming { class ExaTrkXHook { public: virtual ~ExaTrkXHook() {} - virtual void operator()(const std::any &, const std::any &) const {}; + virtual void operator()(const std::any &nodes, const std::any &edges, + const std::any &weights) const {}; }; class ExaTrkXPipeline { diff --git a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchGraphStoreHook.hpp b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchGraphStoreHook.hpp new file mode 100644 index 00000000000..172b96fce20 --- /dev/null +++ b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchGraphStoreHook.hpp @@ -0,0 +1,34 @@ +// This file is part of the Acts project. +// +// Copyright (C) 2023 CERN for the benefit of the Acts project +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#pragma once + +#include "Acts/Plugins/ExaTrkX/ExaTrkXPipeline.hpp" +#include "Acts/Plugins/ExaTrkX/detail/CantorEdge.hpp" +#include "Acts/Utilities/Logger.hpp" + +namespace Acts { + +class TorchGraphStoreHook : public ExaTrkXHook { + public: + using Graph = std::pair, std::vector>; + + private: + std::unique_ptr m_storedGraph; + + public: + TorchGraphStoreHook(); + ~TorchGraphStoreHook() override {} + + void operator()(const std::any &, const std::any &edges, + const std::any &weights) const override; + + const Graph &storedGraph() const { return *m_storedGraph; } +}; + +} // namespace Acts diff --git a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchTruthGraphMetricsHook.hpp b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchTruthGraphMetricsHook.hpp index a13c9de984d..f971ae2992e 100644 --- a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchTruthGraphMetricsHook.hpp +++ b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchTruthGraphMetricsHook.hpp @@ -25,7 +25,8 @@ class TorchTruthGraphMetricsHook : public ExaTrkXHook { std::unique_ptr l); ~TorchTruthGraphMetricsHook() override {} - void operator()(const std::any &, const std::any &edges) const override; + void operator()(const std::any &, const std::any &edges, + const std::any &) const override; }; } // namespace Acts diff --git a/Plugins/ExaTrkX/src/ExaTrkXPipeline.cpp b/Plugins/ExaTrkX/src/ExaTrkXPipeline.cpp index 8c408413c16..3f8e88150f0 100644 --- a/Plugins/ExaTrkX/src/ExaTrkXPipeline.cpp +++ b/Plugins/ExaTrkX/src/ExaTrkXPipeline.cpp @@ -44,7 +44,7 @@ std::vector> ExaTrkXPipeline::run( timing->graphBuildingTime = t1 - t0; } - hook(nodes, edges); + hook(nodes, edges, {}); std::any edge_weights; timing->classifierTimes.clear(); @@ -63,7 +63,7 @@ std::vector> ExaTrkXPipeline::run( edges = std::move(newEdges); edge_weights = std::move(newWeights); - hook(nodes, edges); + hook(nodes, edges, edge_weights); } t0 = std::chrono::high_resolution_clock::now(); diff --git a/Plugins/ExaTrkX/src/TorchGraphStoreHook.cpp b/Plugins/ExaTrkX/src/TorchGraphStoreHook.cpp new file mode 100644 index 00000000000..0d5bd6b8c65 --- /dev/null +++ b/Plugins/ExaTrkX/src/TorchGraphStoreHook.cpp @@ -0,0 +1,33 @@ +// This file is part of the Acts project. +// +// Copyright (C) 2023 CERN for the benefit of the Acts project +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#include "Acts/Plugins/ExaTrkX/TorchGraphStoreHook.hpp" + +#include "Acts/Plugins/ExaTrkX/detail/TensorVectorConversion.hpp" + +#include + +Acts::TorchGraphStoreHook::TorchGraphStoreHook() { + m_storedGraph = std::make_unique(); +} + +void Acts::TorchGraphStoreHook::operator()(const std::any&, + const std::any& edges, + const std::any& weights) const { + if (not weights.has_value()) { + return; + } + + m_storedGraph->first = detail::tensor2DToVector( + std::any_cast(edges).t()); + + auto cpuWeights = std::any_cast(weights).to(torch::kCPU); + m_storedGraph->second = + std::vector(cpuWeights.data_ptr(), + cpuWeights.data_ptr() + cpuWeights.numel()); +} diff --git a/Plugins/ExaTrkX/src/TorchTruthGraphMetricsHook.cpp b/Plugins/ExaTrkX/src/TorchTruthGraphMetricsHook.cpp index 50712797c39..851142274b9 100644 --- a/Plugins/ExaTrkX/src/TorchTruthGraphMetricsHook.cpp +++ b/Plugins/ExaTrkX/src/TorchTruthGraphMetricsHook.cpp @@ -46,7 +46,8 @@ Acts::TorchTruthGraphMetricsHook::TorchTruthGraphMetricsHook( } void Acts::TorchTruthGraphMetricsHook::operator()(const std::any&, - const std::any& edges) const { + const std::any& edges, + const std::any&) const { // We need to transpose the edges here for the right memory layout const auto edgeIndex = Acts::detail::tensor2DToVector( std::any_cast(edges).t()); diff --git a/Tests/UnitTests/Plugins/ExaTrkX/ExaTrkXMetricHookTests.cpp b/Tests/UnitTests/Plugins/ExaTrkX/ExaTrkXMetricHookTests.cpp index def518202d3..31693c9b588 100644 --- a/Tests/UnitTests/Plugins/ExaTrkX/ExaTrkXMetricHookTests.cpp +++ b/Tests/UnitTests/Plugins/ExaTrkX/ExaTrkXMetricHookTests.cpp @@ -29,7 +29,7 @@ void testTruthTestGraph(std::vector &truthGraph, {static_cast(testGraph.size() / 2), 2}, opts) .transpose(0, 1); - hook({}, edgeTensor); + hook({}, edgeTensor, {}); const auto str = ss.str();