Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add HypE #218

Merged
merged 33 commits into from
Feb 28, 2025
Merged
Show file tree
Hide file tree
Changes from 32 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
62e5c10
update README.md
Zhenyu2Liang Jan 13, 2025
fed1e7b
Merge branch 'evoxtorch-main' of https://github.com/EMI-Group/evox in…
Zhenyu2Liang Jan 13, 2025
e2764e7
update README.md
Zhenyu2Liang Jan 13, 2025
46abfb9
update README.md
Zhenyu2Liang Jan 13, 2025
9a1922a
update README.md
Zhenyu2Liang Jan 13, 2025
ce8d618
update README.md
Zhenyu2Liang Jan 13, 2025
99ae7b4
update README.md
Zhenyu2Liang Jan 13, 2025
92862d8
update README.md
Zhenyu2Liang Jan 14, 2025
433b808
Merge branch 'evoxtorch-main' of https://github.com/EMI-Group/evox in…
Zhenyu2Liang Jan 14, 2025
1b9ab34
update README.md
Zhenyu2Liang Jan 14, 2025
bd98832
update README.md
Zhenyu2Liang Jan 14, 2025
0e1e618
add algorithm figures
Zhenyu2Liang Jan 14, 2025
1a7cfee
update README.md
Zhenyu2Liang Jan 14, 2025
1f1730a
update README.md
Zhenyu2Liang Jan 14, 2025
92f53ad
update README.md
Zhenyu2Liang Jan 14, 2025
67b117a
update pso figure
Zhenyu2Liang Jan 14, 2025
0782e09
update pso figure
Zhenyu2Liang Jan 14, 2025
a004f85
Merge branch 'evoxtorch-main' of https://github.com/EMI-Group/evox in…
Zhenyu2Liang Jan 14, 2025
6466fbf
update rvea figure
Zhenyu2Liang Jan 14, 2025
e4f3b4f
Merge branch 'evoxtorch-main' of https://github.com/EMI-Group/evox in…
Zhenyu2Liang Jan 14, 2025
341c1c8
Merge branch 'evoxtorch-main' of https://github.com/EMI-Group/evox in…
Zhenyu2Liang Jan 14, 2025
dc1b742
Merge branch 'main' of https://github.com/EMI-Group/evox into evoxtor…
Zhenyu2Liang Feb 12, 2025
5df0a7e
fix bug of APD calculation
Zhenyu2Liang Feb 13, 2025
675b13b
Merge branch 'main' of https://github.com/EMI-Group/evox into evoxtor…
Zhenyu2Liang Feb 19, 2025
36d4eb4
add hype
Zhenyu2Liang Feb 25, 2025
e0a292e
Merge branch 'main' of https://github.com/EMI-Group/evox into evoxtor…
Zhenyu2Liang Feb 25, 2025
f8faea3
fix bug of device
Zhenyu2Liang Feb 28, 2025
0025871
add hype
Zhenyu2Liang Feb 28, 2025
66ab263
Merge branch 'main' of https://github.com/EMI-Group/evox into evoxtor…
Zhenyu2Liang Feb 28, 2025
5d18fac
Format code
Zhenyu2Liang Feb 28, 2025
3ea35dd
Format code
Zhenyu2Liang Feb 28, 2025
72be76f
Add docs
Zhenyu2Liang Feb 28, 2025
ea5e535
update hype
Zhenyu2Liang Feb 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/evox/algorithms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,13 @@
"SLPSOUS",
# MOEAs
"RVEA",
"MOEAD",
"NSGA2",
"NSGA3",
"MOEAD",
"HypE",
]

from .mo import MOEAD, NSGA2, NSGA3, RVEA
from .mo import MOEAD, NSGA2, NSGA3, RVEA, HypE
from .so import (
ARS,
ASEBO,
Expand Down
3 changes: 2 additions & 1 deletion src/evox/algorithms/mo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
__all__ = ["RVEA", "NSGA2", "MOEAD", "NSGA3"]
__all__ = ["HypE", "MOEAD", "NSGA2", "NSGA3", "RVEA"]

from .hype import HypE
from .moead import MOEAD
from .nsga2 import NSGA2
from .nsga3 import NSGA3
Expand Down
131 changes: 131 additions & 0 deletions src/evox/algorithms/mo/hype.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
from typing import Callable, Optional

import torch

from evox.core import Algorithm, Mutable
from evox.operators.crossover import simulated_binary
from evox.operators.mutation import polynomial_mutation
from evox.operators.selection import non_dominate_rank, tournament_selection
from evox.utils import clamp, lexsort


def cal_hv(fit: torch.Tensor, ref: torch.Tensor, pop_size: int, n_sample: int):
n, m = fit.size()
alpha = torch.cumprod(torch.cat([torch.ones(1, device=fit.device), (pop_size - torch.arange(1, n, device=fit.device)) / (n - torch.arange(1, n, device=fit.device))]), dim=0) / torch.arange(1, n + 1, device=fit.device)
alpha = torch.nan_to_num(alpha)

f_min = torch.min(fit, dim=0).values

samples = torch.rand(n_sample, m, device=fit.device) * (ref - f_min) + f_min

ds = torch.zeros(n_sample, dtype=torch.int64, device=fit.device)
pds = (fit.unsqueeze(0).expand(n_sample, -1, -1) - samples.unsqueeze(1).expand(-1, n, -1) <= 0).all(dim=2)
ds = torch.sum(torch.where(pds, ds.unsqueeze(1) + 1, ds.unsqueeze(1)), dim=1)
ds = torch.where(ds == 0, ds, ds - 1)

temp = torch.where(pds.T, ds.unsqueeze(0), -1)
value = torch.where(temp != -1, alpha[temp], torch.tensor(0, dtype=torch.float32))
f = torch.sum(value, dim=1)

f = f * torch.prod(ref - f_min) / n_sample
return f


class HypE(Algorithm):
"""The tensoried version of HypE algorithm.

:reference: https://direct.mit.edu/evco/article-abstract/19/1/45/1363/HypE-An-Algorithm-for-Fast-Hypervolume-Based-Many
"""

def __init__(
self,
pop_size: int,
n_objs: int,
lb: torch.Tensor,
ub: torch.Tensor,
n_sample=10000,
selection_op: Optional[Callable] = None,
mutation_op: Optional[Callable] = None,
crossover_op: Optional[Callable] = None,
device: torch.device | None = None,
):
"""Initializes the HypE algorithm.

:param pop_size: The size of the population.
:param n_objs: The number of objective functions in the optimization problem.
:param lb: The lower bounds for the decision variables (1D tensor).
:param ub: The upper bounds for the decision variables (1D tensor).
:param n_sample: The number of samples for hypervolume calculation (optional).
:param selection_op: The selection operation for evolutionary strategy (optional).
:param mutation_op: The mutation operation, defaults to `polynomial_mutation` if not provided (optional).
:param crossover_op: The crossover operation, defaults to `simulated_binary` if not provided (optional).
:param device: The device on which computations should run (optional). Defaults to PyTorch's default device.
"""

super().__init__()
self.pop_size = pop_size
self.n_objs = n_objs
if device is None:
device = torch.get_default_device()
# check
assert lb.shape == ub.shape and lb.ndim == 1 and ub.ndim == 1
assert lb.dtype == ub.dtype and lb.device == ub.device
self.dim = lb.shape[0]
# write to self
self.lb = lb.to(device=device)
self.ub = ub.to(device=device)
self.n_sample = n_sample

self.selection = selection_op
self.mutation = mutation_op
self.crossover = crossover_op

self.selection = tournament_selection
if self.mutation is None:
self.mutation = polynomial_mutation
if self.crossover is None:
self.crossover = simulated_binary

length = ub - lb
population = torch.rand(self.pop_size, self.dim, device=device)
population = length * population + lb

self.ref = Mutable(torch.ones(n_objs, device=device))

self.pop = Mutable(population)
self.fit = Mutable(torch.full((self.pop_size, self.n_objs), torch.inf, device=device))


def init_step(self):
"""
Perform the initialization step of the workflow.

Calls the `init_step` of the algorithm if overwritten; otherwise, its `step` method will be invoked.
"""
self.fit = self.evaluate(self.pop)
self.ref = torch.full((self.n_objs,), torch.max(self.fit).item() * 1.2, device=self.fit.device)

def step(self):
"""Perform the optimization step of the workflow."""
hv = cal_hv(self.fit, self.ref, self.pop_size, self.n_sample)
mating_pool = self.selection(self.pop_size, -hv)
crossovered = self.crossover(self.pop[mating_pool])
offspring = self.mutation(crossovered, self.lb, self.ub)
offspring = clamp(offspring, self.lb, self.ub)
off_fit = self.evaluate(offspring)

merge_pop = torch.cat([self.pop, offspring], dim=0)
merge_fit = torch.cat([self.fit, off_fit], dim=0)

rank = non_dominate_rank(merge_fit)
order = torch.argsort(rank)
worst_rank = rank[order[self.pop_size - 1]]
mask = rank <= worst_rank

hv = cal_hv(merge_fit, self.ref, torch.sum(mask) - self.pop_size, self.n_sample)
dis = torch.where(mask, hv, -torch.inf)

combined_indices = lexsort([-dis, rank])[: self.pop_size]

self.pop = merge_pop[combined_indices]
self.fit = merge_fit[combined_indices]
5 changes: 2 additions & 3 deletions src/evox/operators/selection/tournament_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,10 @@ def tournament_selection_multifit(n_round: int, fitnesses: List[torch.Tensor], t
This function performs tournament selection by randomly selecting a group of candidates for each round,
and selecting the best one from each group based on their fitness values across multiple objectives.
"""

fitness_tensor = torch.stack(fitnesses, dim=1)

num_candidates = fitness_tensor.size(0)
parents = torch.randint(0, num_candidates, (n_round, tournament_size))
parents = torch.randint(0, num_candidates, (n_round, tournament_size), device=fitnesses[0].device)
candidates_fitness = fitness_tensor[parents]
candidates_fitness = lexsort(candidates_fitness.unbind(-1))

Expand All @@ -45,7 +44,7 @@ def tournament_selection(n_round: int, fitness: torch.Tensor, tournament_size: i

num_candidates = fitness.size(0)

parents = torch.randint(0, num_candidates, (n_round, tournament_size))
parents = torch.randint(0, num_candidates, (n_round, tournament_size), device=fitness.device)
candidates_fitness = fitness[parents]

winner_indices = torch.argmin(candidates_fitness, dim=1)
Expand Down
7 changes: 4 additions & 3 deletions unit_test/algorithms/test_moea.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import torch

from evox.algorithms import MOEAD, NSGA2, NSGA3, RVEA
from evox.algorithms import MOEAD, NSGA2, NSGA3, RVEA, HypE
from evox.core import Algorithm, use_state, vmap
from evox.problems.numerical import DTLZ2
from evox.workflows import StdWorkflow
Expand Down Expand Up @@ -38,15 +38,16 @@ def run_vmap_algorithm(self, algo: Algorithm):

class TestMOVariants(MOTestBase):
def setUp(self):
pop_size = 100
dim = 12
pop_size = 20
dim = 10
lb = -torch.ones(dim)
ub = torch.ones(dim)
self.algo = [
NSGA2(pop_size=pop_size, n_objs=3, lb=lb, ub=ub),
NSGA3(pop_size=pop_size, n_objs=3, lb=lb, ub=ub),
RVEA(pop_size=pop_size, n_objs=3, lb=lb, ub=ub),
MOEAD(pop_size=pop_size, n_objs=3, lb=lb, ub=ub),
HypE(pop_size=pop_size, n_objs=3, lb=lb, ub=ub),
]

def test_moea_variants(self):
Expand Down
Loading