Skip to content

Commit

Permalink
added an attribute so sklearn knows estimator has been fitted. also k…
Browse files Browse the repository at this point in the history
…eep track of feature names
  • Loading branch information
perib committed Aug 12, 2024
1 parent 7906c44 commit 2b19c72
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions tpot2/search_spaces/nodes/genetic_feature_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ def __init__(self, mask):
self.mask = mask

def fit(self, X, y=None):
self.n_features_in_ = X.shape[1]
if isinstance(X, pd.DataFrame):
self.feature_names_in_ = X.columns
# self.set_output(transform="pandas")
self.is_fitted_ = True #so sklearn knows it's fitted
return self

def _get_tags(self):
Expand All @@ -28,6 +33,8 @@ def _get_tags(self):
def _get_support_mask(self):
return np.array(self.mask)

def get_feature_names_out(self, input_features=None):
return self.feature_names_in_[self.get_support()]

class GeneticFeatureSelectorIndividual(SklearnIndividual):
def __init__( self,
Expand Down

0 comments on commit 2b19c72

Please sign in to comment.