diff --git a/RecoTauTag/RecoTau/plugins/DeepTauId.cc b/RecoTauTag/RecoTau/plugins/DeepTauId.cc index c38347b3c0125..a5e33f2fa4b8a 100644 --- a/RecoTauTag/RecoTau/plugins/DeepTauId.cc +++ b/RecoTauTag/RecoTau/plugins/DeepTauId.cc @@ -299,9 +299,12 @@ struct CellIndex { class CellGrid { public: - CellGrid(unsigned _nCellsEta, unsigned _nCellsPhi, double _cellSizeEta, double _cellSizePhi) : - nCellsEta(_nCellsEta), nCellsPhi(_nCellsPhi), nTotal(nCellsEta * nCellsPhi), - cellSizeEta(_cellSizeEta), cellSizePhi(_cellSizePhi), cells(nTotal) + using Map = std::map; + using const_iterator = Map::const_iterator; + + CellGrid(unsigned n_cells_eta, unsigned n_cells_phi, double cell_size_eta, double cell_size_phi) : + nCellsEta(n_cells_eta), nCellsPhi(n_cells_phi), nTotal(nCellsEta * nCellsPhi), + cellSizeEta(cell_size_eta), cellSizePhi(cell_size_phi) { if(nCellsEta % 2 != 1 || nCellsEta < 1) throw cms::Exception("DeepTauId") << "Invalid number of eta cells."; @@ -311,10 +314,12 @@ class CellGrid { throw cms::Exception("DeepTauId") << "Invalid cell size."; } - int MaxEtaIndex() const { return static_cast((nCellsEta - 1) / 2); } - int MaxPhiIndex() const { return static_cast((nCellsPhi - 1) / 2); } - double MaxDeltaEta() const { return cellSizeEta * (0.5 + MaxEtaIndex()); } - double MaxDeltaPhi() const { return cellSizePhi * (0.5 + MaxPhiIndex()); } + int maxEtaIndex() const { return static_cast((nCellsEta - 1) / 2); } + int maxPhiIndex() const { return static_cast((nCellsPhi - 1) / 2); } + double maxDeltaEta() const { return cellSizeEta * (0.5 + maxEtaIndex()); } + double maxDeltaPhi() const { return cellSizePhi * (0.5 + maxPhiIndex()); } + int getEtaTensorIndex(const CellIndex& cellIndex) const { return cellIndex.eta + maxEtaIndex(); } + int getPhiTensorIndex(const CellIndex& cellIndex) const { return cellIndex.phi + maxPhiIndex(); } bool TryGetCellIndex(double deltaEta, double deltaPhi, CellIndex& cellIndex) const { @@ -326,28 +331,21 @@ class CellGrid { return true; }; - return getCellIndex(deltaEta, MaxDeltaEta(), cellSizeEta, cellIndex.eta) - && getCellIndex(deltaPhi, MaxDeltaPhi(), cellSizePhi, cellIndex.phi); + return getCellIndex(deltaEta, maxDeltaEta(), cellSizeEta, cellIndex.eta) + && getCellIndex(deltaPhi, maxDeltaPhi(), cellSizePhi, cellIndex.phi); } - Cell& at(const CellIndex& cellIndex) { return cells.at(GetFlatIndex(cellIndex)); } - const Cell& at(const CellIndex& cellIndex) const { return cells.at(GetFlatIndex(cellIndex)); } - bool IsEmpty(const CellIndex& cellIndex) const { return at(cellIndex).empty(); } - -private: - size_t GetFlatIndex(const CellIndex& cellIndex) const - { - if(std::abs(cellIndex.eta) > MaxEtaIndex() || std::abs(cellIndex.phi) > MaxPhiIndex()) - throw cms::Exception("DeepTauId") << "Cell index is out of range"; - const unsigned shiftedEta = static_cast(cellIndex.eta + MaxEtaIndex()); - const unsigned shiftedPhi = static_cast(cellIndex.phi + MaxPhiIndex()); - return shiftedEta * nCellsPhi + shiftedPhi; - } + Cell& operator[](const CellIndex& cellIndex) { return cells[cellIndex]; } + const Cell& at(const CellIndex& cellIndex) const { return cells.at(cellIndex); } + const_iterator begin() const { return cells.begin(); } + const_iterator end() const { return cells.end(); } -private: +public: const unsigned nCellsEta, nCellsPhi, nTotal; const double cellSizeEta, cellSizePhi; - std::vector cells; + +private: + std::map cells; }; } // anonymous namespace @@ -535,7 +533,7 @@ class DeepTauId : public deep_tau::DeepTauBase { const CellObjectType obj_type = GetCellObjectType(obj); CellIndex cell_index; if(grid.TryGetCellIndex(deta, dphi, cell_index)) { - Cell& cell = grid.at(cell_index); + Cell& cell = grid[cell_index]; auto iter = cell.find(obj_type); if(iter != cell.end()) { const auto& prev_obj = objects.at(iter->second); @@ -550,7 +548,7 @@ class DeepTauId : public deep_tau::DeepTauBase { for(size_t n = 0; n < objects.size(); ++n) { const auto& obj = objects.at(n); const double deta = obj.polarP4().eta() - tau.polarP4().eta(); - const double dphi = ROOT::Math::VectorUtil::DeltaPhi(obj.polarP4(), tau.polarP4()); + const double dphi = ROOT::Math::VectorUtil::DeltaPhi(tau.polarP4(), obj.polarP4()); const double dR2 = std::pow(deta, 2) + std::pow(dphi, 2); if(dR2 < inner_dR2) addObject(n, deta, dphi, inner_grid);