Skip to content

Commit

Permalink
Merge pull request #524 from rest-for-physics/jgalan_dataset_updates
Browse files Browse the repository at this point in the history
Component and dataset upgrades
  • Loading branch information
jgalan authored Jun 14, 2024
2 parents 8b97d34 + fe6aad7 commit 9591b5a
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 34 deletions.
16 changes: 11 additions & 5 deletions source/framework/core/inc/TRestDataSet.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ class TRestDataSet : public TRestMetadata {
Bool_t fExternal = false; //<

/// The resulting RDF::RNode object after initialization
ROOT::RDF::RNode fDataSet = ROOT::RDataFrame(0); //!
ROOT::RDF::RNode fDataFrame = ROOT::RDataFrame(0); //!

/// A pointer to the generated tree
TChain* fTree = nullptr; //!
Expand All @@ -122,12 +122,14 @@ class TRestDataSet : public TRestMetadata {
protected:
virtual std::vector<std::string> FileSelection();

void RegenerateTree(std::vector<std::string> finalList = {});

public:
/// Gives access to the RDataFrame
ROOT::RDF::RNode GetDataFrame() const {
if (!fExternal && fTree == nullptr)
RESTWarning << "DataFrame has not been yet initialized" << RESTendl;
return fDataSet;
return fDataFrame;
}

void EnableMultiThreading(Bool_t enable = true) { fMT = enable; }
Expand All @@ -152,7 +154,7 @@ class TRestDataSet : public TRestMetadata {
}

/// Number of variables (or observables)
size_t GetNumberOfColumns() { return fDataSet.GetColumnNames().size(); }
size_t GetNumberOfColumns() { return fDataFrame.GetColumnNames().size(); }

/// Number of variables (or observables)
size_t GetNumberOfBranches() { return GetNumberOfColumns(); }
Expand Down Expand Up @@ -187,7 +189,7 @@ class TRestDataSet : public TRestMetadata {

void SetTotalTimeInSeconds(Double_t seconds) { fTotalDuration = seconds; }
void SetDataFrame(const ROOT::RDF::RNode& dS) {
fDataSet = dS;
fDataFrame = dS;
fExternal = true;
}

Expand All @@ -198,8 +200,12 @@ class TRestDataSet : public TRestMetadata {
void Export(const std::string& filename, std::vector<std::string> excludeColumns = {});

ROOT::RDF::RNode MakeCut(const TRestCut* cut);
ROOT::RDF::RNode ApplyRange(size_t from, size_t to);
ROOT::RDF::RNode Range(size_t from, size_t to);
ROOT::RDF::RNode DefineColumn(const std::string& columnName, const std::string& formula);

size_t GetEntries();

void PrintMetadata() override;
void Initialize() override;

Expand All @@ -209,6 +215,6 @@ class TRestDataSet : public TRestMetadata {
TRestDataSet(const char* cfgFileName, const std::string& name = "");
~TRestDataSet();

ClassDefOverride(TRestDataSet, 7);
ClassDefOverride(TRestDataSet, 8);
};
#endif
78 changes: 60 additions & 18 deletions source/framework/core/src/TRestDataSet.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -382,30 +382,40 @@ void TRestDataSet::GenerateDataSet() {
ROOT::DisableImplicitMT();

RESTInfo << "Initializing dataset" << RESTendl;
fDataSet = ROOT::RDataFrame("AnalysisTree", fFileSelection);
fDataFrame = ROOT::RDataFrame("AnalysisTree", fFileSelection);

RESTInfo << "Making cuts" << RESTendl;
fDataSet = MakeCut(fCut);
fDataFrame = MakeCut(fCut);

// Adding new user columns added to the dataset
for (const auto& [cName, cExpression] : fColumnNameExpressions) {
RESTInfo << "Adding column to dataset: " << cName << RESTendl;
finalList.emplace_back(cName);
fDataSet = DefineColumn(cName, cExpression);
fDataFrame = DefineColumn(cName, cExpression);
}

RegenerateTree(finalList);

RESTInfo << " - Dataset generated!" << RESTendl;
}

///////////////////////////////////////////////
/// \brief It regenerates the tree so that it is an exact copy of the present DataFrame
///
void TRestDataSet::RegenerateTree(std::vector<std::string> finalList) {
RESTInfo << "Generating snapshot." << RESTendl;
std::string user = getenv("USER");
std::string fOutName = "/tmp/rest_output_" + user + ".root";
fDataSet.Snapshot("AnalysisTree", fOutName, finalList);
if (!finalList.empty())
fDataFrame.Snapshot("AnalysisTree", fOutName, finalList);
else
fDataFrame.Snapshot("AnalysisTree", fOutName);

RESTInfo << "Re-importing analysis tree." << RESTendl;
fDataSet = ROOT::RDataFrame("AnalysisTree", fOutName);
fDataFrame = ROOT::RDataFrame("AnalysisTree", fOutName);

TFile* f = TFile::Open(fOutName.c_str());
fTree = (TChain*)f->Get("AnalysisTree");

RESTInfo << " - Dataset generated!" << RESTendl;
}

///////////////////////////////////////////////
Expand Down Expand Up @@ -517,14 +527,32 @@ std::vector<std::string> TRestDataSet::FileSelection() {
return fFileSelection;
}

///////////////////////////////////////////////
/// \brief This method returns a RDataFrame node with the number of
/// samples inside the dataset by selecting a range. It will not
/// modify internally the dataset. See ApplyRange to modify internally
/// the dataset.
///
ROOT::RDF::RNode TRestDataSet::Range(size_t from, size_t to) { return fDataFrame.Range(from, to); }

///////////////////////////////////////////////
/// \brief This method reduces the number of samples inside the
/// dataset by selecting a range.
///
ROOT::RDF::RNode TRestDataSet::ApplyRange(size_t from, size_t to) {
fDataFrame = fDataFrame.Range(from, to);
RegenerateTree();
return fDataFrame;
}

///////////////////////////////////////////////
/// \brief This function applies a TRestCut to the dataframe
/// and returns a dataframe with the applied cuts. Note that
/// the cuts are not applied directly to the dataframe on
/// TRestDataSet, to do so you should do fDataSet = MakeCut(fCut);
/// TRestDataSet, to do so you should do fDataFrame = MakeCut(fCut);
///
ROOT::RDF::RNode TRestDataSet::MakeCut(const TRestCut* cut) {
auto df = fDataSet;
auto df = fDataFrame;

if (cut == nullptr) return df;

Expand Down Expand Up @@ -561,6 +589,20 @@ ROOT::RDF::RNode TRestDataSet::MakeCut(const TRestCut* cut) {
return df;
}

///////////////////////////////////////////////
/// \brief It returns the number of entries found inside fDataFrame
/// and prints out a warning if the number of entries inside the
/// tree is not the same.
///
size_t TRestDataSet::GetEntries() {
auto nEntries = fDataFrame.Count();
if (*nEntries == (long long unsigned int)GetTree()->GetEntries()) return *nEntries;
RESTWarning << "TRestDataSet::GetEntries. Number of tree entries is not the same as RDataFrame entries."
<< RESTendl;
RESTWarning << "Returning RDataFrame entries" << RESTendl;
return *nEntries;
}

///////////////////////////////////////////////
/// \brief This function will add a new column to the RDataFrame using
/// the same scheme as the usual RDF::Define method, but it will on top of
Expand All @@ -574,7 +616,7 @@ ROOT::RDF::RNode TRestDataSet::MakeCut(const TRestCut* cut) {
/// \endcode
///
ROOT::RDF::RNode TRestDataSet::DefineColumn(const std::string& columnName, const std::string& formula) {
auto df = fDataSet;
auto df = fDataFrame;

std::string evalFormula = formula;
for (auto const& [name, properties] : fQuantity)
Expand Down Expand Up @@ -819,7 +861,7 @@ void TRestDataSet::InitFromConfigFile() {
void TRestDataSet::Export(const std::string& filename, std::vector<std::string> excludeColumns) {
RESTInfo << "Exporting dataset" << RESTendl;

std::vector<std::string> columns = fDataSet.GetColumnNames();
std::vector<std::string> columns = fDataFrame.GetColumnNames();
if (!excludeColumns.empty()) {
columns.erase(std::remove_if(columns.begin(), columns.end(),
[&excludeColumns](std::string elem) {
Expand All @@ -831,10 +873,10 @@ void TRestDataSet::Export(const std::string& filename, std::vector<std::string>
RESTInfo << "Re-Generating snapshot." << RESTendl;
std::string user = getenv("USER");
std::string fOutName = "/tmp/rest_output_" + user + ".root";
fDataSet.Snapshot("AnalysisTree", fOutName, columns);
fDataFrame.Snapshot("AnalysisTree", fOutName, columns);

RESTInfo << "Re-importing analysis tree." << RESTendl;
fDataSet = ROOT::RDataFrame("AnalysisTree", fOutName);
fDataFrame = ROOT::RDataFrame("AnalysisTree", fOutName);

TFile* f = TFile::Open(fOutName.c_str());
fTree = (TChain*)f->Get("AnalysisTree");
Expand All @@ -846,7 +888,7 @@ void TRestDataSet::Export(const std::string& filename, std::vector<std::string>
RESTInfo << "Re-Generating snapshot." << RESTendl;
std::string user = getenv("USER");
std::string fOutName = "/tmp/rest_output_" + user + ".root";
fDataSet.Snapshot("AnalysisTree", fOutName);
fDataFrame.Snapshot("AnalysisTree", fOutName);

TFile* f = TFile::Open(fOutName.c_str());
fTree = (TChain*)f->Get("AnalysisTree");
Expand Down Expand Up @@ -910,7 +952,7 @@ void TRestDataSet::Export(const std::string& filename, std::vector<std::string>
fprintf(f, "###\n");
fprintf(f, "### Data starts here\n");

auto obsNames = fDataSet.GetColumnNames();
auto obsNames = fDataFrame.GetColumnNames();
std::string obsListStr = "";
for (const auto& l : obsNames) {
if (!obsListStr.empty()) obsListStr += ":";
Expand Down Expand Up @@ -938,7 +980,7 @@ void TRestDataSet::Export(const std::string& filename, std::vector<std::string>

return;
} else if (TRestTools::GetFileNameExtension(filename) == "root") {
fDataSet.Snapshot("AnalysisTree", filename);
fDataFrame.Snapshot("AnalysisTree", filename);

TFile* f = TFile::Open(filename.c_str(), "UPDATE");
std::string name = this->GetName();
Expand Down Expand Up @@ -1038,7 +1080,7 @@ void TRestDataSet::Import(const std::string& fileName) {
else
ROOT::DisableImplicitMT();

fDataSet = ROOT::RDataFrame("AnalysisTree", fileName);
fDataFrame = ROOT::RDataFrame("AnalysisTree", fileName);

fTree = (TChain*)file->Get("AnalysisTree");
}
Expand Down Expand Up @@ -1104,7 +1146,7 @@ void TRestDataSet::Import(std::vector<std::string> fileNames) {
}

RESTInfo << "Opening list of files. First file: " << fileNames[0] << RESTendl;
fDataSet = ROOT::RDataFrame("AnalysisTree", fileNames);
fDataFrame = ROOT::RDataFrame("AnalysisTree", fileNames);

if (fTree != nullptr) {
delete fTree;
Expand Down
8 changes: 7 additions & 1 deletion source/framework/sensitivity/inc/TRestComponentDataSet.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@ class TRestComponentDataSet : public TRestComponent {
/// The dataset used to initialize the distribution
TRestDataSet fDataSet; //!

/// It helps to split large datasets when extracting the parameterization nodes
long long unsigned int fSplitEntries = 600000000;

/// It creates a sample subset using a range definition
TVector2 fDFRange = TVector2(0, 0);

/// It is true of the dataset was loaded without issues
Bool_t fDataSetLoaded = false; //!

Expand Down Expand Up @@ -84,6 +90,6 @@ class TRestComponentDataSet : public TRestComponent {
TRestComponentDataSet(const char* cfgFileName, const std::string& name);
~TRestComponentDataSet();

ClassDefOverride(TRestComponentDataSet, 3);
ClassDefOverride(TRestComponentDataSet, 4);
};
#endif
32 changes: 22 additions & 10 deletions source/framework/sensitivity/src/TRestComponentDataSet.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,11 @@ void TRestComponentDataSet::PrintMetadata() {
RESTMetadata << " " << RESTendl;
}

if (fDFRange.X() != 0 || fDFRange.Y() != 0) {
RESTMetadata << " DataFrame range: ( " << fDFRange.X() << ", " << fDFRange.Y() << ")" << RESTendl;
RESTMetadata << " " << RESTendl;
}

if (!fParameter.empty() && fParameterizationNodes.empty()) {
RESTMetadata << "This component has no nodes!" << RESTendl;
RESTMetadata << " Use: LoadDataSets() to initialize the nodes" << RESTendl;
Expand Down Expand Up @@ -383,15 +388,17 @@ std::vector<Double_t> TRestComponentDataSet::ExtractParameterizationNodes() {
return vs;
}

auto parValues = fDataSet.GetDataFrame().Take<double>(fParameter);
for (const auto v : parValues) vs.push_back(v);
auto GetUniqueElements = [](const std::vector<double>& vec) {
std::set<double> uniqueSet(vec.begin(), vec.end());
return std::vector<double>(uniqueSet.begin(), uniqueSet.end());
};

std::vector<double>::iterator ip;
ip = std::unique(vs.begin(), vs.begin() + vs.size());
vs.resize(std::distance(vs.begin(), ip));
std::sort(vs.begin(), vs.end());
ip = std::unique(vs.begin(), vs.end());
vs.resize(std::distance(vs.begin(), ip));
for (size_t n = 0; n < 1 + fDataSet.GetEntries() / fSplitEntries; n++) {
auto nEn = fDataSet.Range(n * fSplitEntries, (n + 1) * fSplitEntries).Count();
auto parValues = fDataSet.Range(n * fSplitEntries, (n + 1) * fSplitEntries).Take<double>(fParameter);
std::vector<double> uniqueVec = GetUniqueElements(*parValues);
vs.insert(vs.end(), uniqueVec.begin(), uniqueVec.end());
}

return vs;
}
Expand Down Expand Up @@ -476,6 +483,9 @@ Bool_t TRestComponentDataSet::LoadDataSets() {
fDataSet.Import(fullFileNames);
fDataSetLoaded = true;

if (fDFRange.X() != 0 || fDFRange.Y() != 0)
fDataSet.ApplyRange((size_t)fDFRange.X(), (size_t)fDFRange.Y());

if (fDataSet.GetTree() == nullptr) {
RESTError << "Problem loading dataset from file list :" << RESTendl;
for (const auto& f : fDataSetFileNames) RESTError << " - " << f << RESTendl;
Expand All @@ -486,6 +496,7 @@ Bool_t TRestComponentDataSet::LoadDataSets() {

if (VariablesOk() && WeightsOk()) {
fParameterizationNodes = ExtractParameterizationNodes();
RESTInfo << "Filling histograms" << RESTendl;
FillHistograms();
return fDataSetLoaded;
}
Expand Down Expand Up @@ -515,11 +526,12 @@ Bool_t TRestComponentDataSet::WeightsOk() {
Bool_t ok = true;
std::vector cNames = fDataSet.GetDataFrame().GetColumnNames();

for (const auto& var : fWeights)
if (std::count(cNames.begin(), cNames.end(), var) == 0) {
for (const auto& var : fWeights) {
if (!isANumber(var) && std::count(cNames.begin(), cNames.end(), var) == 0) {
RESTError << "Weight ---> " << var << " <--- NOT found on dataset" << RESTendl;
ok = false;
}
}
return ok;
}

Expand Down

0 comments on commit 9591b5a

Please sign in to comment.