From 2b19c72e764462f808ec4d836e17889d3c4088bd Mon Sep 17 00:00:00 2001 From: perib Date: Mon, 12 Aug 2024 13:20:09 -0700 Subject: [PATCH] added an attribute so sklearn knows estimator has been fitted. also keep track of feature names --- tpot2/search_spaces/nodes/genetic_feature_selection.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tpot2/search_spaces/nodes/genetic_feature_selection.py b/tpot2/search_spaces/nodes/genetic_feature_selection.py index 97418cd1..28482120 100644 --- a/tpot2/search_spaces/nodes/genetic_feature_selection.py +++ b/tpot2/search_spaces/nodes/genetic_feature_selection.py @@ -19,6 +19,11 @@ def __init__(self, mask): self.mask = mask def fit(self, X, y=None): + self.n_features_in_ = X.shape[1] + if isinstance(X, pd.DataFrame): + self.feature_names_in_ = X.columns + # self.set_output(transform="pandas") + self.is_fitted_ = True #so sklearn knows it's fitted return self def _get_tags(self): @@ -28,6 +33,8 @@ def _get_tags(self): def _get_support_mask(self): return np.array(self.mask) + def get_feature_names_out(self, input_features=None): + return self.feature_names_in_[self.get_support()] class GeneticFeatureSelectorIndividual(SklearnIndividual): def __init__( self,