From db3248bb6f979cd4e62f4b5fc834c35e56147453 Mon Sep 17 00:00:00 2001 From: Jay Wang Date: Wed, 9 Nov 2022 04:50:17 -0500 Subject: [PATCH] Add a method to export decision paths Signed-off-by: Jay Wang --- treefarms/model/treefarms.py | 39 ++++++++++++++++++++++++++---------- 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/treefarms/model/treefarms.py b/treefarms/model/treefarms.py index 5df6174..b6ab8e1 100644 --- a/treefarms/model/treefarms.py +++ b/treefarms/model/treefarms.py @@ -31,7 +31,7 @@ def __init__(self, configuration={}): self.model_set = None self.dataset = None - # TODO: implement this + # TODO: implement this def load(self, path): """ Parameters @@ -87,7 +87,7 @@ def __train__(self, X, y): self.model_set = ModelSetContainer(result) print(f"training completed. Number of trees in the Rashomon set: {self.model_set.get_tree_count()}") - + def fit(self, X, y): """ @@ -104,7 +104,7 @@ def fit(self, X, y): self.__train__(X, y) return self - # TODO: implement this + # TODO: implement this def predict(self, X): """ Parameters @@ -133,7 +133,7 @@ def __getitem__(self, idx): if self.model_set is None: raise Exception("Error: Model not yet trained") return self.model_set.__getitem__(idx) - + def get_tree_count(self): """Returns the number of trees in the Rashomon set @@ -145,9 +145,10 @@ def get_tree_count(self): if self.model_set is None: raise Exception("Error: Model not yet trained") return self.model_set.get_tree_count() - - def visualize(self, feature_names=None, feature_description=None, *, width=500, height=650): - """Generates a visualization of the Rashomon set using `timbertrek` + + def get_decision_paths(self, feature_names=None, feature_description=None): + """Create a hierarchical dictionary describing the decision paths in the + Rashomon set using `timbertrek`. Parameters --- feature_names : matrix-like, shape = [m_features + 1] @@ -155,20 +156,36 @@ def visualize(self, feature_names=None, feature_description=None, *, width=500, """ if self.model_set is None: raise Exception("Error: Model not yet trained") + + # Convert the trie structure to decision paths trie = self.model_set.to_trie() df = self.dataset if feature_names is None: feature_names = df.columns - + decision_paths = timbertrek.transform_trie_to_rules( trie, df, feature_names=feature_names, feature_description=feature_description, ) - - # return decision_paths - + + return decision_paths + + def visualize(self, feature_names=None, feature_description=None, *, width=500, height=650): + """Generates a visualization of the Rashomon set using `timbertrek` + Parameters + --- + feature_names : matrix-like, shape = [m_features + 1] + a matrix where each row is a sample to be predicted and each column is a feature to be used for prediction + """ + # Get the decision paths + decision_paths = self.get_decision_paths( + feature_names=feature_names, + feature_description=feature_description + ) + + # Show in the in-notebook visualization timbertrek.visualize(decision_paths, width=width, height=height) def __translate__(self, leaves):