diff --git a/golem/core/optimisers/genetic/gp_params.py b/golem/core/optimisers/genetic/gp_params.py index 7579dc96..8132d6fc 100644 --- a/golem/core/optimisers/genetic/gp_params.py +++ b/golem/core/optimisers/genetic/gp_params.py @@ -83,8 +83,7 @@ class GPAlgorithmParameters(AlgorithmParameters): adaptive_mutation_type: MutationAgentTypeEnum = MutationAgentTypeEnum.default context_agent_type: Union[ContextAgentTypeEnum, Callable] = ContextAgentTypeEnum.nodes_num - selection_types: Sequence[SelectionTypesEnum] = \ - (SelectionTypesEnum.tournament,) + selection_types: Sequence[Union[SelectionTypesEnum, Any]] = None crossover_types: Sequence[Union[CrossoverTypesEnum, Any]] = \ (CrossoverTypesEnum.one_point,) mutation_types: Sequence[Union[MutationTypesEnum, Any]] = simple_mutation_set @@ -96,7 +95,9 @@ class GPAlgorithmParameters(AlgorithmParameters): window_size: Optional[int] = None def __post_init__(self): + if not self.selection_types: + self.selection_types = (SelectionTypesEnum.spea2,) if self.multi_objective \ + else (SelectionTypesEnum.tournament,) if self.multi_objective: - self.selection_types = (SelectionTypesEnum.spea2,) # TODO add possibility of using regularization in MO alg self.regularization_type = RegularizationTypesEnum.none diff --git a/golem/core/optimisers/genetic/operators/selection.py b/golem/core/optimisers/genetic/operators/selection.py index 14850a83..c2f1ab72 100644 --- a/golem/core/optimisers/genetic/operators/selection.py +++ b/golem/core/optimisers/genetic/operators/selection.py @@ -33,6 +33,8 @@ def _selection_by_type(selection_type: SelectionTypesEnum) -> Callable[[Populati } if selection_type in selections: return selections[selection_type] + elif isinstance(selection_type, Callable): + return selection_type else: raise ValueError(f'Required selection not found: {selection_type}') diff --git a/test/unit/optimizers/gp_operators/test_selection.py b/test/unit/optimizers/gp_operators/test_selection.py index b0a8e67e..922fe03f 100644 --- a/test/unit/optimizers/gp_operators/test_selection.py +++ b/test/unit/optimizers/gp_operators/test_selection.py @@ -1,9 +1,15 @@ from golem.core.adapter import DirectAdapter from golem.core.optimisers.genetic.gp_params import GPAlgorithmParameters +from golem.core.optimisers.genetic.operators.operator import PopulationT from golem.core.optimisers.genetic.operators.selection import Selection, SelectionTypesEnum, random_selection from golem.core.optimisers.opt_history_objects.individual import Individual from test.unit.optimizers.test_evaluation import get_objective from test.unit.utils import graph_first, graph_second, graph_third, graph_fourth, graph_fifth +from random import sample + + +def custom_selection(population: PopulationT, pop_size: int): + return sample(population, pop_size) def get_population(): @@ -56,3 +62,13 @@ def test_individuals_selection_equality_individuals(): selected_individuals_ref = [str(ind) for ind in selected_individuals] assert (len(selected_individuals) == num_of_inds and len(set(selected_individuals_ref)) == 1) + + +def test_custom_selection(): + num_of_inds = 3 + population = get_population() + requirements = GPAlgorithmParameters(selection_types=[custom_selection], pop_size=num_of_inds) + selection = Selection(requirements) + selected_individuals = selection(population) + assert (all([ind in population for ind in selected_individuals]) and + len(selected_individuals) == num_of_inds)