Skip to content

Commit

Permalink
Merge branch 'refactor-resurrect-detector-root-volumes' of github.com…
Browse files Browse the repository at this point in the history
…:andiwand/acts into refactor-resurrect-detector-root-volumes
  • Loading branch information
andiwand committed Apr 20, 2023
2 parents 559af7a + ed5f5ae commit 83b9615
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
3 changes: 0 additions & 3 deletions Examples/Python/tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -1119,9 +1119,6 @@ def test_full_chain_odd_example_pythia_geant4(tmp_path):
)


@pytest.mark.skip(
reason="as of https://github.com/acts-project/acts/issues/2023 disabling for now"
)
@pytest.mark.skipif(not dd4hepEnabled, reason="DD4hep not set up")
@pytest.mark.skipif(not onnxEnabled, reason="ONNX plugin not enabled")
@pytest.mark.slow
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,14 @@ class AmbiguityTrackClassifier {
std::unordered_map<int, std::vector<int>>& clusters,
const Acts::TrackContainer<track_container_t, traj_t, holder_t>& tracks)
const {
// Compute the number of entry (since it is smaller than the number of
// track)
int trackNb = 0;
for (const auto& [_, val] : clusters) {
trackNb += val.size();
}
// Input of the neural network
Acts::NetworkBatchInput networkInput(tracks.size() + 1, 8);
Acts::NetworkBatchInput networkInput(trackNb, 8);
int inputID = 0;
// Get the input feature of the network for all the tracks
for (const auto& [key, val] : clusters) {
Expand All @@ -56,7 +62,8 @@ class AmbiguityTrackClassifier {
networkInput(inputID, 2) = trajState.nOutliers;
networkInput(inputID, 3) = trajState.nHoles;
networkInput(inputID, 4) = trajState.NDF;
networkInput(inputID, 5) = (trajState.chi2Sum * 1.0) / trajState.NDF;
networkInput(inputID, 5) = (trajState.chi2Sum * 1.0) /
(trajState.NDF != 0 ? trajState.NDF : 1);
networkInput(inputID, 6) = Acts::VectorHelpers::eta(track.momentum());
networkInput(inputID, 7) = Acts::VectorHelpers::phi(track.momentum());
inputID++;
Expand Down

0 comments on commit 83b9615

Please sign in to comment.