From 13079708f9e0705b9bf63a881bc7d55acfdc3f9c Mon Sep 17 00:00:00 2001 From: Mitchell Victoriano Date: Tue, 25 Feb 2025 20:26:36 -0800 Subject: [PATCH] Working through integrating PF generator functions --- .../default/components/analysis_extensions.py | 27 ++------ .../default/components/pf_viewer/__init__.py | 0 .../default/components/pf_viewer/pf_widget.py | 66 +++++++++++++++++++ 3 files changed, 70 insertions(+), 23 deletions(-) create mode 100644 src/badger/gui/default/components/pf_viewer/__init__.py create mode 100644 src/badger/gui/default/components/pf_viewer/pf_widget.py diff --git a/src/badger/gui/default/components/analysis_extensions.py b/src/badger/gui/default/components/analysis_extensions.py index aa26c01b..e22c0d4a 100644 --- a/src/badger/gui/default/components/analysis_extensions.py +++ b/src/badger/gui/default/components/analysis_extensions.py @@ -1,11 +1,11 @@ from abc import abstractmethod from typing import Optional, cast -import pyqtgraph as pg from PyQt5.QtCore import pyqtSignal from PyQt5.QtWidgets import QDialog, QVBoxLayout, QMessageBox from PyQt5.QtGui import QCloseEvent from badger.gui.default.components.bo_visualizer.bo_plotter import BOPlotWidget +from badger.gui.default.components.pf_viewer.pf_widget import ParetoFrontWidget from badger.routine import Routine import logging @@ -36,33 +36,14 @@ def __init__(self, parent: Optional[AnalysisExtension] = None): self.setWindowTitle("Pareto Front Viewer") - self.plot_widget = pg.PlotWidget() - - self.scatter_plot = self.plot_widget.plot(pen=None, symbol="o", symbolSize=10) + self.pw_widget = ParetoFrontWidget() layout = QVBoxLayout() - layout.addWidget(self.plot_widget) + layout.addWidget(self.pw_widget) self.setLayout(layout) def update_window(self, routine: Routine): - if len(routine.vocs.objective_names) != 2: - raise ValueError( - "cannot use pareto front viewer unless there are 2 " "objectives" - ) - - x_name = routine.vocs.objective_names[0] - y_name = routine.vocs.objective_names[1] - - if routine.data is not None: - x = routine.data[x_name] - y = routine.data[y_name] - - # Update the scatter plot - self.scatter_plot.setData(x=x, y=y) - - # set labels - self.plot_widget.setLabel("left", y_name) - self.plot_widget.setLabel("bottom", x_name) + self.pw_widget.update_window(routine) class BOVisualizer(AnalysisExtension): diff --git a/src/badger/gui/default/components/pf_viewer/__init__.py b/src/badger/gui/default/components/pf_viewer/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/badger/gui/default/components/pf_viewer/pf_widget.py b/src/badger/gui/default/components/pf_viewer/pf_widget.py new file mode 100644 index 00000000..2e56339f --- /dev/null +++ b/src/badger/gui/default/components/pf_viewer/pf_widget.py @@ -0,0 +1,66 @@ +from copy import deepcopy +from typing import Optional +from PyQt5.QtWidgets import QWidget +from PyQt5.QtWidgets import QVBoxLayout +import pyqtgraph as pg +from badger.routine import Routine + +from xopt.generators.bayesian.mobo import MOBOGenerator + +import logging + +logger = logging.getLogger(__name__) + + +class ParetoFrontWidget(QWidget): + routine = None + + def __init__(self, parent: Optional[QWidget] = None): + super().__init__(parent=parent) + self.plot_widget = pg.PlotWidget() + self.scatter_plot = self.plot_widget.plot(pen=None, symbol="o", symbolSize=10) + layout = QVBoxLayout() + layout.addWidget(self.plot_widget) + self.setLayout(layout) + + def isValidRoutine(self, routine: Routine): + if routine.vocs.objective_names is None: + logging.error("No objective names") + return False + if len(routine.vocs.objective_names) != 2: + logging.error("Invalid number of objectives") + return False + return + + def update_plot(self, routine: Routine): + self.routine = deepcopy(routine) + + if not self.isValidRoutine(self.routine): + logging.error("Invalid routine") + return + + if not isinstance(self.routine.generator, MOBOGenerator): + logging.error("Invalid generator") + return + + pareto_front = self.routine.generator.get_pareto_front() + + if pareto_front == (None, None): + logging.error("No pareto front") + return + + # aquisition_fn = self.routine.generator.get_acquisition(pareto_front) + + x_name = routine.vocs.objective_names[0] + y_name = routine.vocs.objective_names[1] + + if routine.data is not None: + x = routine.data[x_name] + y = routine.data[y_name] + + # Update the scatter plot + self.scatter_plot.setData(x=x, y=y) + + # set labels + self.plot_widget.setLabel("left", y_name) + self.plot_widget.setLabel("bottom", x_name)