diff --git a/examples/4_advising_models/advise_by_surrogate.py b/examples/4_advising_models/advise_by_surrogate.py index a60e355d..f0083bde 100644 --- a/examples/4_advising_models/advise_by_surrogate.py +++ b/examples/4_advising_models/advise_by_surrogate.py @@ -1,7 +1,4 @@ -"""Example on advising pipelines for a dataset. - -Run this module from the repo root. -""" +""" Example on advising pipelines for a dataset. """ import pickle @@ -22,22 +19,25 @@ SURROGATE_MODEL_HYPERPARAMETERS_FILE = get_checkpoints_dir() / "tabular" / "hparams.yaml" META_FEATURES_EXTRACTOR_CONFIG_FILE = get_configs_dir() / "use_features.json" META_FEATURES_PREPROCESSOR_FILE = ( - get_data_dir() / "pymfe_meta_features_and_fedot_pipelines/all/meta_features_preprocessors.pickle" + get_data_dir() / "pymfe_meta_features_and_fedot_pipelines/all/meta_features_preprocessors.pickle" ) PIPELINES_FILE = get_data_dir() / "pymfe_meta_features_and_fedot_pipelines/all/pipelines_fedot.pickle" def preprocess_dataset_features( - dataset: DatasetBase, - dataset_meta_features_extractor: MetaFeaturesExtractor, - dataset_meta_features_preprocessor: FeaturesPreprocessor, + dataset: DatasetBase, + dataset_meta_features_extractor: MetaFeaturesExtractor, + dataset_meta_features_preprocessor: FeaturesPreprocessor, ) -> Data: """Extract dataset features. Parameters ---------- - dataset: DatasetBase - Dataset to extract features from. + dataset: Dataset to extract features from. + + dataset_meta_features_extractor: Extractor returning dataset meta-features. + + dataset_meta_features_preprocessor: Preprocessor, preparing the meta-features after extraction. Returns ------- @@ -60,22 +60,16 @@ def main(): ) # Build dataset meta features extractor. extractor_params = get_extractor_params(META_FEATURES_EXTRACTOR_CONFIG_FILE) - dataset_meta_features_extractor = PymfeExtractor(extractor_params) + dataset_meta_features_extractor = PymfeExtractor(**extractor_params) # Build dataset meta features preprocessor. - dataset_meta_features_preprocessor = FeaturesPreprocessor( - load_path=META_FEATURES_PREPROCESSOR_FILE, extractor_params=extractor_params - ) + dataset_meta_features_preprocessor = FeaturesPreprocessor(load_path=META_FEATURES_PREPROCESSOR_FILE) # Build pipeline features extractor. pipeline_extractor = FEDOTPipelineFeaturesExtractor( include_operations_hyperparameters=False, operation_encoding="ordinal", ) # Build adviser. - advisor = SurrogateGNNModelAdvisor( - surrogate_model, - dataset_meta_features_extractor, - dataset_meta_features_preprocessor, - ) + advisor = SurrogateGNNModelAdvisor(surrogate_model) # Load datasets. datasets_names = [ "apsfailure", @@ -87,7 +81,9 @@ def main(): # Extract features pipelines_features = [pipeline_extractor(pipeline.save()[0]) for pipeline in pipelines] - datasets_features = [preprocess_dataset_features(dataset) for dataset in datasets] + datasets_features = [ + preprocess_dataset_features(dataset, dataset_meta_features_extractor, dataset_meta_features_preprocessor) for + dataset in datasets] # Make prediction. return advisor.predict(pipelines, datasets, pipelines_features, datasets_features, k=5)