Skip to content

Commit

Permalink
Adding a method for local instance-level explanations
Browse files Browse the repository at this point in the history
  • Loading branch information
groshanlal committed Jan 3, 2024
1 parent a3cc70b commit f40368c
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 6 deletions.
2 changes: 1 addition & 1 deletion te2rules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
"""

__version__ = "0.7.0"
__version__ = "0.8.0"
80 changes: 75 additions & 5 deletions te2rules/explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,10 +193,11 @@ def explain(
min_precision=min_precision,
jaccard_threshold=jaccard_threshold,
)
rules = self.rule_builder.explain(X, y)
rules_as_str = [str(r) for r in rules]
# rules_as_str = self._prune_rules_by_dropping_terms(rules, X, y, min_precision)
return rules_as_str

self.rules = [str(r) for r in self.rule_builder.explain(X, y)]
self.longer_rules = [str(r) for r in self.rule_builder.longer_rules]

return self.rules

def _prune_rules_by_dropping_terms(
self,
Expand Down Expand Up @@ -389,6 +390,60 @@ def get_fidelity(

return self.rule_builder.get_fidelity()

def explain_instance_with_rules(
    self, X: List[List[float]], explore_all_rules: bool = True
) -> List[List[str]]:
    """
    Explain each input instance by listing the rules it satisfies.

    Every candidate rule is evaluated against the input rows with
    ``pandas.DataFrame.query``. For each row, the rules whose condition
    holds on that row are collected, in the order the rules are stored.
    A row that satisfies no rule gets an empty list.

    This method does not call the model itself; it relies on the rules
    stored in ``self.rules`` / ``self.longer_rules``, which are set by
    ``explain()``. ``explain()`` must therefore be called first.

    Parameters
    ----------
    X: 2d numpy.array
        2 dimensional data with feature values, one row per instance.
        Columns must be in the same order as ``self.feature_names``.

    explore_all_rules: boolean, optional
        When truthy (default), all candidate rules found by ``explain()``
        before the set-cover selection (``self.longer_rules``) are
        considered. When falsy, only the condensed subset of rules
        returned by ``explain()`` (``self.rules``) is considered.

    Returns
    -------
    list of explanations, one per row of X, in row order. Each explanation
    is a list of rules, each of which is independently satisfied by that
    instance.
    """
    dataframe = pd.DataFrame(X, columns=self.feature_names)

    # Truthiness rather than `is True`, so that truthy non-bool flags
    # (e.g. numpy.bool_ or 1) select the full rule set as intended.
    rules = self.longer_rules if explore_all_rules else self.rules

    # One independent list per input row; the DataFrame built above has a
    # default 0..n-1 index, so index labels double as row positions.
    selected_rules: List[List[str]] = [[] for _ in range(len(dataframe))]

    for rule in rules:
        # Index labels of the rows on which this rule's condition holds.
        support = dataframe.query(str(rule)).index.tolist()
        for row in support:
            selected_rules[row].append(rule)

    return selected_rules


class RuleBuilder:
"""
Expand Down Expand Up @@ -497,6 +552,8 @@ class prediction from the tree ensemble model using input data
self.solution_rules = self._deduplicate(self.solution_rules)
log.info(str(len(self.solution_rules)) + " solutions")

self.longer_rules = [r for r in self.solution_rules]

# log.info("")
# log.info("Removing subset rules")
# self._remove_subset_rules()
Expand Down Expand Up @@ -548,7 +605,20 @@ def _rules_to_cover_positives(self, positives: List[int]) -> None:
original_rules[max_coverage_rule].decision_rule
):
max_coverage_rule = rule

"""
scores = self._score_rule_using_data(
original_rules[max_coverage_rule], self.labels
)
avg_score = 0.0
if len(scores) > 0:
avg_score = sum(scores) / len(scores)
print(
int(len(set(self.positives).intersection(set(original_rules[max_coverage_rule].decision_support)))/len(set(self.positives))*100*10)/10,
int(avg_score*100*10)/10,
max_coverage_rule
)
"""
selected_rules.append(original_rules[max_coverage_rule])
new_covered_positives = positive_coverage[max_coverage_rule]
covered_positives = list(
Expand Down

0 comments on commit f40368c

Please sign in to comment.