Skip to content

Commit

Permalink
Working through integrating PF generator functions
Browse files Browse the repository at this point in the history
  • Loading branch information
MitchellAV committed Feb 26, 2025
1 parent 132fc41 commit 1307970
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 23 deletions.
27 changes: 4 additions & 23 deletions src/badger/gui/default/components/analysis_extensions.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from abc import abstractmethod
from typing import Optional, cast

import pyqtgraph as pg
from PyQt5.QtCore import pyqtSignal
from PyQt5.QtWidgets import QDialog, QVBoxLayout, QMessageBox
from PyQt5.QtGui import QCloseEvent
from badger.gui.default.components.bo_visualizer.bo_plotter import BOPlotWidget
from badger.gui.default.components.pf_viewer.pf_widget import ParetoFrontWidget
from badger.routine import Routine

import logging
Expand Down Expand Up @@ -36,33 +36,14 @@ def __init__(self, parent: Optional[AnalysisExtension] = None):

self.setWindowTitle("Pareto Front Viewer")

self.plot_widget = pg.PlotWidget()

self.scatter_plot = self.plot_widget.plot(pen=None, symbol="o", symbolSize=10)
self.pw_widget = ParetoFrontWidget()

layout = QVBoxLayout()
layout.addWidget(self.plot_widget)
layout.addWidget(self.pw_widget)
self.setLayout(layout)

def update_window(self, routine: Routine):
if len(routine.vocs.objective_names) != 2:
raise ValueError(
"cannot use pareto front viewer unless there are 2 " "objectives"
)

x_name = routine.vocs.objective_names[0]
y_name = routine.vocs.objective_names[1]

if routine.data is not None:
x = routine.data[x_name]
y = routine.data[y_name]

# Update the scatter plot
self.scatter_plot.setData(x=x, y=y)

# set labels
self.plot_widget.setLabel("left", y_name)
self.plot_widget.setLabel("bottom", x_name)
self.pw_widget.update_window(routine)


class BOVisualizer(AnalysisExtension):
Expand Down
Empty file.
66 changes: 66 additions & 0 deletions src/badger/gui/default/components/pf_viewer/pf_widget.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from copy import deepcopy
from typing import Optional
from PyQt5.QtWidgets import QWidget
from PyQt5.QtWidgets import QVBoxLayout
import pyqtgraph as pg
from badger.routine import Routine

from xopt.generators.bayesian.mobo import MOBOGenerator

import logging

logger = logging.getLogger(__name__)


class ParetoFrontWidget(QWidget):
routine = None

def __init__(self, parent: Optional[QWidget] = None):
super().__init__(parent=parent)
self.plot_widget = pg.PlotWidget()
self.scatter_plot = self.plot_widget.plot(pen=None, symbol="o", symbolSize=10)
layout = QVBoxLayout()
layout.addWidget(self.plot_widget)
self.setLayout(layout)

def isValidRoutine(self, routine: Routine):
if routine.vocs.objective_names is None:
logging.error("No objective names")
return False
if len(routine.vocs.objective_names) != 2:
logging.error("Invalid number of objectives")
return False
return

def update_plot(self, routine: Routine):
self.routine = deepcopy(routine)

if not self.isValidRoutine(self.routine):
logging.error("Invalid routine")
return

if not isinstance(self.routine.generator, MOBOGenerator):
logging.error("Invalid generator")
return

pareto_front = self.routine.generator.get_pareto_front()

if pareto_front == (None, None):
logging.error("No pareto front")
return

# aquisition_fn = self.routine.generator.get_acquisition(pareto_front)

x_name = routine.vocs.objective_names[0]
y_name = routine.vocs.objective_names[1]

if routine.data is not None:
x = routine.data[x_name]
y = routine.data[y_name]

# Update the scatter plot
self.scatter_plot.setData(x=x, y=y)

# set labels
self.plot_widget.setLabel("left", y_name)
self.plot_widget.setLabel("bottom", x_name)

0 comments on commit 1307970

Please sign in to comment.