Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Solve the CI issue with the ML solver #2026

Merged
merged 10 commits into from
Apr 20, 2023
3 changes: 0 additions & 3 deletions Examples/Python/tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -1119,9 +1119,6 @@ def test_full_chain_odd_example_pythia_geant4(tmp_path):
)


@pytest.mark.skip(
reason="as of https://github.com/acts-project/acts/issues/2023 disabling for now"
)
@pytest.mark.skipif(not dd4hepEnabled, reason="DD4hep not set up")
@pytest.mark.skipif(not onnxEnabled, reason="ONNX plugin not enabled")
@pytest.mark.slow
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,14 @@ class AmbiguityTrackClassifier {
std::unordered_map<int, std::vector<int>>& clusters,
const Acts::TrackContainer<track_container_t, traj_t, holder_t>& tracks)
const {
// Compute the number of entry (since it is smaller than the number of
// track)
int trackNb = 0;
for (const auto& [_, val] : clusters) {
trackNb += val.size();
}
// Input of the neural network
Acts::NetworkBatchInput networkInput(tracks.size() + 1, 8);
Acts::NetworkBatchInput networkInput(trackNb, 8);
int inputID = 0;
// Get the input feature of the network for all the tracks
for (const auto& [key, val] : clusters) {
Expand All @@ -56,7 +62,8 @@ class AmbiguityTrackClassifier {
networkInput(inputID, 2) = trajState.nOutliers;
networkInput(inputID, 3) = trajState.nHoles;
networkInput(inputID, 4) = trajState.NDF;
networkInput(inputID, 5) = (trajState.chi2Sum * 1.0) / trajState.NDF;
networkInput(inputID, 5) = (trajState.chi2Sum * 1.0) /
(trajState.NDF != 0 ? trajState.NDF : 1);
networkInput(inputID, 6) = Acts::VectorHelpers::eta(track.momentum());
networkInput(inputID, 7) = Acts::VectorHelpers::phi(track.momentum());
inputID++;
Expand Down