Skip to content

Commit

Permalink
[multiproc] Correct entries ranges w/ nWorkers>nEntries
Browse files Browse the repository at this point in the history
When processing trees with less entries than workers with TTreeProcessorMP
some entries were processed multiple times because of a mistake in the
algorithm calculating the event ranges.

Fixes root-project#15425
  • Loading branch information
dpiparo committed Aug 7, 2024
1 parent bc0db53 commit 0818b67
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 35 deletions.
42 changes: 23 additions & 19 deletions tree/treeplayer/inc/TMPWorkerTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,33 +199,37 @@ void TMPWorkerTreeFunc<F>::Process(UInt_t code, MPCodeBufPair &msg)
return;
}

// If we are not done processing entries in the tree,
// create a TTreeReader that reads this range of entries
TTreeReader reader(fTree, enl);
if (start >= 0 && start < fTree->GetEntries()) {
TTreeReader reader(fTree, enl);

TTreeReader::EEntryStatus status = reader.SetEntriesRange(start, finish);
if (status != TTreeReader::kEntryValid) {
reply = sn + "could not set TTreeReader to range " + std::to_string(start) + " " + std::to_string(finish - 1);
MPSend(GetSocket(), MPCode::kProcError, reply.c_str());
return;
}

TTreeReader::EEntryStatus status = reader.SetEntriesRange(start, finish);
if(status != TTreeReader::kEntryValid) {
reply = sn + "could not set TTreeReader to range " + std::to_string(start) + " " + std::to_string(finish - 1);
MPSend(GetSocket(), MPCode::kProcError, reply.c_str());
return;
}
// execute function
auto res = fProcFunc(reader);

//execute function
auto res = fProcFunc(reader);
// detach result from file if needed (currently needed for TH1, TTree, TEventList)
DetachRes(res);

//detach result from file if needed (currently needed for TH1, TTree, TEventList)
DetachRes(res);
if (fCanReduce) {
PoolUtils::ReduceObjects<TObject *> redfunc;
fReducedResult = static_cast<decltype(fReducedResult)>(redfunc(
{res, fReducedResult})); // TODO try not to copy these into a vector, do everything by ref. std::vector<T&>?
} else {
fCanReduce = true;
fReducedResult = res;
}
}

//update the number of processed entries
fProcessedEntries += finish - start;

if(fCanReduce) {
PoolUtils::ReduceObjects<TObject *> redfunc;
fReducedResult = static_cast<decltype(fReducedResult)>(redfunc({res, fReducedResult})); //TODO try not to copy these into a vector, do everything by ref. std::vector<T&>?
} else {
fCanReduce = true;
fReducedResult = res;
}

if(fMaxNEntries == fProcessedEntries)
//we are done forever
MPSend(GetSocket(), MPCode::kProcResult, fReducedResult);
Expand Down
16 changes: 12 additions & 4 deletions tree/treeplayer/src/TMPWorkerTree.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -310,12 +310,16 @@ Int_t TMPWorkerTree::LoadTree(UInt_t code, MPCodeBufPair &msg, Long64_t &start,
UInt_t nBunch = nEntries / fNWorkers;
UInt_t rangeN = nProcessed % fNWorkers;
start = rangeN * nBunch;
if (rangeN < (fNWorkers - 1)) {
if (start >= nEntries) {
start = finish = nEntries;
}
else if (rangeN < (fNWorkers - 1)) {
finish = (rangeN+1)*nBunch;
} else {
finish = nEntries;
}


//process tree
tree = fTree;
CloseFile(); // May not be needed
Expand Down Expand Up @@ -389,7 +393,9 @@ Int_t TMPWorkerTree::LoadTree(UInt_t code, MPCodeBufPair &msg, Long64_t &start,
if(nEntries % fNWorkers) nBunch++;
UInt_t rangeN = nProcessed % fNWorkers;
start = rangeN * nBunch;
if(rangeN < (fNWorkers-1))
if (start >= nEntries)
start = finish = nEntries;
else if(rangeN < (fNWorkers-1))
finish = (rangeN+1)*nBunch;
else
finish = nEntries;
Expand All @@ -409,12 +415,14 @@ Int_t TMPWorkerTree::LoadTree(UInt_t code, MPCodeBufPair &msg, Long64_t &start,
if (code == MPCode::kProcRange) {
// example: for 21 entries, 4 workers we want ranges 0-5, 5-10, 10-15, 15-21
// and this worker must take the rangeN-th range
ULong64_t nEntries = (*enl)->GetN();
Long64_t nEntries = (*enl)->GetN();
UInt_t nBunch = nEntries / fNWorkers;
if (nEntries % fNWorkers) nBunch++;
UInt_t rangeN = nProcessed % fNWorkers;
start = rangeN * nBunch;
if (rangeN < (fNWorkers - 1))
if (start >= nEntries) {
start = finish = nEntries;
} else if (rangeN < (fNWorkers - 1))
finish = (rangeN + 1) * nBunch;
else
finish = nEntries;
Expand Down
6 changes: 4 additions & 2 deletions tree/treeplayer/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@ if(NOT MSVC OR win_broken_tests)
endif()

if(imt)
ROOT_ADD_GTEST(treeprocessormt treeprocmt/treeprocessormt.cxx LIBRARIES TreePlayer)
if (NOT MSVC)
ROOT_ADD_GTEST(treeprocessors treeprocs/treeprocessors.cxx LIBRARIES TreePlayer)
endif()
if(xrootd)
ROOT_ADD_GTEST(treeprocessormt_remotefiles treeprocmt/treeprocessormt_remotefiles.cxx LIBRARIES TreePlayer)
ROOT_ADD_GTEST(treeprocessormt_remotefiles treeprocs/treeprocessormt_remotefiles.cxx LIBRARIES TreePlayer)
endif()
endif()

Expand Down
Original file line number Diff line number Diff line change
@@ -1,21 +1,25 @@
#include <algorithm>
#include <atomic>
#include <chrono>
#include <random>
#include <string>
#include <thread>
#include <utility>

#include <TFile.h>
#include <TParameter.h>
#include <TTree.h>
#include <TSystem.h>
#include <TTreeReader.h>
#include <TTreeReaderValue.h>
#ifndef MSVC
#include <ROOT/TTreeProcessorMP.hxx>
#endif
#include <ROOT/TTreeProcessorMT.hxx>

#include "gtest/gtest.h"

void WriteFiles(const std::vector<std::string> &treenames, const std::vector<std::string> &filenames)
#include <algorithm>
#include <atomic>
#include <chrono>
#include <random>
#include <string>
#include <thread>
#include <utility>

void WriteFiles(const std::vector<std::string> &treenames, const std::vector<std::string> &filenames, int nEvts = 10)
{
int v = 0;
const auto nFiles = filenames.size();
Expand All @@ -27,7 +31,7 @@ void WriteFiles(const std::vector<std::string> &treenames, const std::vector<std
TFile file(fname.c_str(), "recreate");
TTree t(treename.c_str(), treename.c_str());
t.Branch("v", &v);
for (auto e = 0; e < 10; ++e) {
for (auto e = 0; e < nEvts; ++e) {
++v;
t.Fill();
}
Expand Down Expand Up @@ -56,6 +60,70 @@ void DeleteFiles(const std::vector<std::string> &filenames)
gSystem->Unlink(f.c_str());
}

class FilesRAII {
std::vector<std::string> fFileNames;

public:
FilesRAII(const std::vector<std::string> &treenames, const std::vector<std::string> &filenames, int nEvts = 10)
: fFileNames(filenames)
{
WriteFiles(treenames, filenames, nEvts);
}
~FilesRAII() { DeleteFiles(fFileNames); }
};

#ifndef MSVC

class TestSelector : public TSelector {
public:
TParameter<int> fParameter;

virtual void SlaveBegin(TTree *) {}
virtual bool Process(Long64_t)
{
auto newVal = fParameter.GetVal() + 1;
fParameter.SetVal(newVal);
return true;
}
virtual void SlaveTerminate() { GetOutputList()->Add(fParameter.Clone()); }
};

// See issue #15425
TEST(TreeProcessorMP, moreWorkersThanEvents)
{
auto func = [](TTreeReader &r) {
int n = 0;
while (r.Next())
n++;
auto par = new TParameter<int>("n", n);
return par;
};

std::vector<std::string> files = {"f_moreWorkersThanEvents.root"};
std::vector<std::string> trees = {"t"};

FilesRAII fr(trees, files, 2);

// Test function
{
ROOT::TTreeProcessorMP proc(3);
auto res = proc.Process(files, func);

EXPECT_EQ(2, res->GetVal()) << "The counter incremented in the worker processes has the wrong value.";

delete res;
}
// Test selector
{
ROOT::TTreeProcessorMP pool(3);
TestSelector sel;
auto resl = pool.Process(files[0], sel);
auto tparami = (TParameter<int> *)resl->At(0);
EXPECT_EQ(2, tparami->GetVal()) << "The counter incremented in the worker processes has the wrong value.";
}
}
#endif // MSVC

TEST(TreeProcessorMT, EmptyTChain)
{
TChain c("mytree");
Expand Down

0 comments on commit 0818b67

Please sign in to comment.