Skip to content

Commit

Permalink
Merge pull request #3 from kandrosov/deepTauId_v2_work
Browse files Browse the repository at this point in the history
Improved CellGrid implementation.
  • Loading branch information
MRD2F authored Apr 20, 2019
2 parents c163c2d + 3ad13e0 commit 98417e2
Showing 1 changed file with 24 additions and 26 deletions.
50 changes: 24 additions & 26 deletions RecoTauTag/RecoTau/plugins/DeepTauId.cc
Original file line number Diff line number Diff line change
Expand Up @@ -299,9 +299,12 @@ struct CellIndex {

class CellGrid {
public:
CellGrid(unsigned _nCellsEta, unsigned _nCellsPhi, double _cellSizeEta, double _cellSizePhi) :
nCellsEta(_nCellsEta), nCellsPhi(_nCellsPhi), nTotal(nCellsEta * nCellsPhi),
cellSizeEta(_cellSizeEta), cellSizePhi(_cellSizePhi), cells(nTotal)
using Map = std::map<CellIndex, Cell>;
using const_iterator = Map::const_iterator;

CellGrid(unsigned n_cells_eta, unsigned n_cells_phi, double cell_size_eta, double cell_size_phi) :
nCellsEta(n_cells_eta), nCellsPhi(n_cells_phi), nTotal(nCellsEta * nCellsPhi),
cellSizeEta(cell_size_eta), cellSizePhi(cell_size_phi)
{
if(nCellsEta % 2 != 1 || nCellsEta < 1)
throw cms::Exception("DeepTauId") << "Invalid number of eta cells.";
Expand All @@ -311,10 +314,12 @@ class CellGrid {
throw cms::Exception("DeepTauId") << "Invalid cell size.";
}

int MaxEtaIndex() const { return static_cast<int>((nCellsEta - 1) / 2); }
int MaxPhiIndex() const { return static_cast<int>((nCellsPhi - 1) / 2); }
double MaxDeltaEta() const { return cellSizeEta * (0.5 + MaxEtaIndex()); }
double MaxDeltaPhi() const { return cellSizePhi * (0.5 + MaxPhiIndex()); }
int maxEtaIndex() const { return static_cast<int>((nCellsEta - 1) / 2); }
int maxPhiIndex() const { return static_cast<int>((nCellsPhi - 1) / 2); }
double maxDeltaEta() const { return cellSizeEta * (0.5 + maxEtaIndex()); }
double maxDeltaPhi() const { return cellSizePhi * (0.5 + maxPhiIndex()); }
int getEtaTensorIndex(const CellIndex& cellIndex) const { return cellIndex.eta + maxEtaIndex(); }
int getPhiTensorIndex(const CellIndex& cellIndex) const { return cellIndex.phi + maxPhiIndex(); }

bool TryGetCellIndex(double deltaEta, double deltaPhi, CellIndex& cellIndex) const
{
Expand All @@ -326,28 +331,21 @@ class CellGrid {
return true;
};

return getCellIndex(deltaEta, MaxDeltaEta(), cellSizeEta, cellIndex.eta)
&& getCellIndex(deltaPhi, MaxDeltaPhi(), cellSizePhi, cellIndex.phi);
return getCellIndex(deltaEta, maxDeltaEta(), cellSizeEta, cellIndex.eta)
&& getCellIndex(deltaPhi, maxDeltaPhi(), cellSizePhi, cellIndex.phi);
}

Cell& at(const CellIndex& cellIndex) { return cells.at(GetFlatIndex(cellIndex)); }
const Cell& at(const CellIndex& cellIndex) const { return cells.at(GetFlatIndex(cellIndex)); }
bool IsEmpty(const CellIndex& cellIndex) const { return at(cellIndex).empty(); }

private:
size_t GetFlatIndex(const CellIndex& cellIndex) const
{
if(std::abs(cellIndex.eta) > MaxEtaIndex() || std::abs(cellIndex.phi) > MaxPhiIndex())
throw cms::Exception("DeepTauId") << "Cell index is out of range";
const unsigned shiftedEta = static_cast<unsigned>(cellIndex.eta + MaxEtaIndex());
const unsigned shiftedPhi = static_cast<unsigned>(cellIndex.phi + MaxPhiIndex());
return shiftedEta * nCellsPhi + shiftedPhi;
}
Cell& operator[](const CellIndex& cellIndex) { return cells[cellIndex]; }
const Cell& at(const CellIndex& cellIndex) const { return cells.at(cellIndex); }
const_iterator begin() const { return cells.begin(); }
const_iterator end() const { return cells.end(); }

private:
public:
const unsigned nCellsEta, nCellsPhi, nTotal;
const double cellSizeEta, cellSizePhi;
std::vector<Cell> cells;

private:
std::map<CellIndex, Cell> cells;
};

} // anonymous namespace
Expand Down Expand Up @@ -535,7 +533,7 @@ class DeepTauId : public deep_tau::DeepTauBase {
const CellObjectType obj_type = GetCellObjectType(obj);
CellIndex cell_index;
if(grid.TryGetCellIndex(deta, dphi, cell_index)) {
Cell& cell = grid.at(cell_index);
Cell& cell = grid[cell_index];
auto iter = cell.find(obj_type);
if(iter != cell.end()) {
const auto& prev_obj = objects.at(iter->second);
Expand All @@ -550,7 +548,7 @@ class DeepTauId : public deep_tau::DeepTauBase {
for(size_t n = 0; n < objects.size(); ++n) {
const auto& obj = objects.at(n);
const double deta = obj.polarP4().eta() - tau.polarP4().eta();
const double dphi = ROOT::Math::VectorUtil::DeltaPhi(obj.polarP4(), tau.polarP4());
const double dphi = ROOT::Math::VectorUtil::DeltaPhi(tau.polarP4(), obj.polarP4());
const double dR2 = std::pow(deta, 2) + std::pow(dphi, 2);
if(dR2 < inner_dR2)
addObject(n, deta, dphi, inner_grid);
Expand Down

0 comments on commit 98417e2

Please sign in to comment.