Skip to content

Commit

Permalink
Merge pull request #145 from michaellans/template
Browse files Browse the repository at this point in the history
Adding Save/Load Template to Routine Page
  • Loading branch information
wenatuhs authored Feb 27, 2025
2 parents 132fc41 + 2933c3e commit d7f8dd7
Show file tree
Hide file tree
Showing 4 changed files with 280 additions and 4 deletions.
10 changes: 10 additions & 0 deletions src/badger/gui/acr/components/env_cbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,16 @@ def init_ui(self):
vbox = QVBoxLayout(self)
vbox.setContentsMargins(8, 8, 8, 8)

# Load Template Button
template_button = QWidget()
template_button.setFixedWidth(128)
hbox_name = QHBoxLayout(template_button)
hbox_name.setContentsMargins(0, 0, 0, 0)
self.load_template_button = load_template_button = QPushButton("Load Template")
hbox_name.addWidget(load_template_button, 0)
vbox.addWidget(template_button)
template_button.show()

self.setObjectName("EnvBox")

name = QWidget()
Expand Down
4 changes: 4 additions & 0 deletions src/badger/gui/acr/components/routine_editor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ class BadgerRoutineEditor(QWidget):
sig_saved = pyqtSignal()
sig_canceled = pyqtSignal()
sig_deleted = pyqtSignal()
sig_load_template = pyqtSignal(str)
sig_save_template = pyqtSignal(str)

def __init__(self):
super().__init__()
Expand Down Expand Up @@ -41,6 +43,8 @@ def init_ui(self):
# scroll_area.setWidgetResizable(True)
# scroll_area.setWidget(routine_page)
stacks.addWidget(routine_page)
routine_page.sig_load_template.connect(self.sig_load_template.emit)
routine_page.sig_save_template.connect(self.sig_save_template.emit)

stacks.setCurrentIndex(1)
vbox.addWidget(stacks)
Expand Down
267 changes: 263 additions & 4 deletions src/badger/gui/acr/components/routine_page.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import numpy as np
import pandas as pd
from PyQt5.QtCore import Qt, pyqtSignal
from PyQt5.QtWidgets import QLineEdit, QLabel, QPushButton
from PyQt5.QtWidgets import QLineEdit, QLabel, QPushButton, QFileDialog
from PyQt5.QtWidgets import QListWidgetItem, QMessageBox, QWidget, QTabWidget
from PyQt5.QtWidgets import QVBoxLayout, QHBoxLayout, QScrollArea
from PyQt5.QtWidgets import QTableWidgetItem, QPlainTextEdit
Expand Down Expand Up @@ -59,6 +59,8 @@

class BadgerRoutinePage(QWidget):
sig_updated = pyqtSignal(str, str) # routine name, routine description
sig_load_template = pyqtSignal(str) # template path
sig_save_template = pyqtSignal(str) # template path

def __init__(self):
super().__init__()
Expand Down Expand Up @@ -98,6 +100,10 @@ def __init__(self):
# Trigger the re-rendering of the environment box
self.env_box.relative_to_curr.setChecked(True)

# Template path
self.template_dir = os.path.join(self.BADGER_PLUGIN_ROOT, "templates")
# self.template_dir = "/home/physics/mlans/workspace/badger_test/Badger/src/badger/built_in_plugins/templates"

def init_ui(self):
config_singleton = init_settings()

Expand Down Expand Up @@ -158,6 +164,18 @@ def init_ui(self):
vbox_meta.addWidget(descr)
descr_bar.hide()

# Save Template Button
template_button = QWidget()
hbox_name = QHBoxLayout(template_button)
hbox_name.setContentsMargins(0, 0, 0, 0)
self.save_template_button = save_template_button = QPushButton(
"Save as Template"
)
save_template_button.setFixedWidth(128)
hbox_name.addWidget(save_template_button, alignment=Qt.AlignRight)
vbox_meta.addWidget(template_button, alignment=Qt.AlignBottom)
template_button.show()

# Tags
self.cbox_tags = cbox_tags = BadgerFilterBox(title=" Tags")
if not strtobool(config_singleton.read_value("BADGER_ENABLE_ADVANCED")):
Expand All @@ -168,7 +186,9 @@ def init_ui(self):
# vbox.addWidget(group_meta)

# Env box
BADGER_PLUGIN_ROOT = config_singleton.read_value("BADGER_PLUGIN_ROOT")
self.BADGER_PLUGIN_ROOT = BADGER_PLUGIN_ROOT = config_singleton.read_value(
"BADGER_PLUGIN_ROOT"
)
env_dict_dir = os.path.join(
BADGER_PLUGIN_ROOT, "environments", "env_colors.yaml"
)
Expand All @@ -180,7 +200,8 @@ def init_ui(self):
self.env_box = BadgerEnvBox(env_dict, None, self.envs)
scroll_area = QScrollArea()
scroll_area.setFrameShape(QScrollArea.NoFrame)
scroll_area.setStyleSheet("""
scroll_area.setStyleSheet(
"""
QScrollArea {
border: none; /* Remove border */
margin: 0px; /* Remove margin */
Expand All @@ -189,7 +210,8 @@ def init_ui(self):
QScrollArea > QWidget {
margin: 0px; /* Remove margin inside */
}
""")
"""
)
scroll_content_env = QWidget()
scroll_layout_env = QVBoxLayout(scroll_content_env)
scroll_layout_env.setContentsMargins(0, 0, 15, 0)
Expand All @@ -208,6 +230,8 @@ def init_ui(self):

def config_logic(self):
self.btn_descr_update.clicked.connect(self.update_description)
self.env_box.load_template_button.clicked.connect(self.load_template_yaml)
self.save_template_button.clicked.connect(self.save_template_yaml)
self.generator_box.cb.currentIndexChanged.connect(self.select_generator)
self.generator_box.btn_docs.clicked.connect(self.open_generator_docs)
self.generator_box.check_use_script.stateChanged.connect(self.toggle_use_script)
Expand All @@ -232,6 +256,241 @@ def config_logic(self):
self.env_box.var_table.sig_sel_changed.connect(self.update_init_table)
self.env_box.var_table.sig_pv_added.connect(self.handle_pv_added)

def load_template_yaml(self) -> None:
"""
Load data from template .yaml into template_dict dictionary.
This function expects to be called via an action from
a QPushButton
"""

if isinstance(self.sender(), QPushButton):
# load template from button
options = QFileDialog.Options()
template_path, _ = QFileDialog.getOpenFileName(
self,
"Select YAML File",
self.template_dir,
"YAML Files (*.yaml);;All Files (*)",
options=options,
)

if not template_path:
return

# Load template file
try:
with open(template_path, "r") as stream:
template_dict = yaml.safe_load(stream)
self.set_options_from_template(template_dict=template_dict)
self.sig_load_template.emit(
f"Options loaded from template: {os.path.basename(template_path)}"
)
except (FileNotFoundError, yaml.YAMLError) as e:
print(f"Error loading template: {e}")
return

def set_options_from_template(self, template_dict: dict):
"""
Fills in routine_page GUI with relevant info from template_dict
dictionary
"""

# Compose the template
try:
name = template_dict["name"]
description = template_dict["description"]
relative_to_current = template_dict["relative_to_current"]
generator_name = template_dict["generator"]["name"]
env_name = template_dict["environment"]["name"]
vrange_limit_options = template_dict["vrange_limit_options"]
initial_point_actions = template_dict[
"initial_point_actions"
] # should be type: add_curr
critical_constraint_names = template_dict["critical_constraint_names"]
env_params = template_dict["environment"]["params"]
except KeyError as e:
QMessageBox.warning(self, "Error", f"Missing key in template: {e}")
return

# set vocs
vocs = VOCS(
variables=template_dict["vocs"]["variables"],
objectives=template_dict["vocs"]["objectives"],
constraints=template_dict["vocs"]["constraints"],
constants={},
observables=template_dict["vocs"]["observables"],
)

# set description
self.edit_descr.setPlainText(description)

# set generator
if generator_name in self.generators:
i = self.generators.index(generator_name)
self.generator_box.cb.setCurrentIndex(i)

filtered_config = filter_generator_config(
generator_name, template_dict["generator"]
)
self.generator_box.edit.setPlainText(get_yaml_string(filtered_config))

# set environment
if env_name in self.envs:
i = self.envs.index(env_name)
self.env_box.cb.setCurrentIndex(i)
self.env_box.edit.setPlainText(get_yaml_string(env_params))

# set init points based on relative to current
if relative_to_current:
# make sure gui checkbox state matches yaml option
if not self.env_box.relative_to_curr.isChecked():
self.env_box.relative_to_curr.setChecked(True)

else:
if self.env_box.relative_to_curr.isChecked():
self.env_box.relative_to_curr.setChecked(False)

self.ratio_var_ranges = vrange_limit_options
self.init_table_actions = initial_point_actions

self.env_box.init_table.clear()

# set bounds (should this be somewhere else?)
if env_name:
bounds = self.calc_auto_bounds()
self.env_box.var_table.set_bounds(bounds)

# set selected variables
self.env_box.var_table.set_selected(vocs.variables)
# self.env_box.var_table.set_bounds(vocs.variables)
self.env_box.check_only_var.setChecked(True)

if not relative_to_current:
# set initial points to sample
self._fill_init_table()

# set objectives
self.env_box.obj_table.set_selected(vocs.objectives)
self.env_box.obj_table.set_rules(vocs.objectives)
self.env_box.check_only_obj.setChecked(True)

# set constraints
constraints = vocs.constraints
if len(constraints):
for name, val in constraints.items():
relation, thres = val
critical = name in critical_constraint_names
relation = ["GREATER_THAN", "LESS_THAN", "EQUAL_TO"].index(relation)
self.add_constraint(name, relation, thres, critical)
else:
self.env_box.list_con.clear()

# set observables
observables = vocs.observable_names
if len(observables):
for name_sta in observables:
self.add_state(name_sta)
else:
self.env_box.list_obs.clear()

def generate_template_dict_from_gui(self):
"""
Generate a template dictionary from the current state of the GUI
"""

vocs, critical_constraints = self._compose_vocs()

vrange_limit_options = {}

for var in self.env_box.var_table.variables:
# set bounds to variable range limits
name = next(iter(var))
if self.env_box.var_table.is_checked(name):
bounds = var[name]
if name in vocs.variables:
vocs.variables[name] = bounds

# Record the ratio var ranges
if self.env_box.relative_to_curr.isChecked():
# set all to self.limit_option (I don't think auto mode *currently allows
# setting different ranges for different vars)
for vname in vocs.variables:
vrange_limit_options[vname] = copy.deepcopy(self.limit_option)
else:
# Set vrange_limit_options based on current table info
# Set each based on bounds in table --> convert to percentage of full range
var_bounds = self.env_box.var_table.export_variables()
for var_name in var_bounds:
# get bounds from table
vocs_bounds = vocs.variables[var_name]
bound_range = vocs_bounds[1] - vocs_bounds[0]
desired_bound_range = var_bounds[var_name][1] - var_bounds[var_name][0]

# calc percentage of full range
ratio_full = desired_bound_range / bound_range

# calc percentage of current value
# Probably a better way to get current value?
var_curr = var_bounds[var_name][0] + 0.5 * desired_bound_range
ratio_curr = float(desired_bound_range / np.abs(var_curr))

vrange_limit_options[var_name] = {
"limit_option_idx": 1,
"ratio_curr": ratio_curr,
"ratio_full": ratio_full,
}

template_dict = {
"name": self.edit_save.text(),
"description": str(self.edit_descr.toPlainText()),
"relative_to_current": self.env_box.relative_to_curr.isChecked(),
"generator": {
"name": self.generator_box.cb.currentText(),
}
| load_config(self.generator_box.edit.toPlainText()),
"environment": {
"name": self.env_box.cb.currentText(),
"params": load_config(self.env_box.edit.toPlainText()),
},
"vrange_limit_options": vrange_limit_options,
"initial_point_actions": self.init_table_actions,
"critical_constraint_names": critical_constraints,
"vocs": vars(vocs),
"badger_version": get_badger_version(),
"xopt_version": get_xopt_version(),
}

return template_dict

def save_template_yaml(self):
"""
Save the current routine as a template .yaml file
"""

template_dict = self.generate_template_dict_from_gui()

options = QFileDialog.Options()
template_path, _ = QFileDialog.getSaveFileName(
self,
"Save Template",
self.template_dir,
"YAML Files (*.yaml);;All Files (*)",
options=options,
)

if not template_path:
return

try:
with open(template_path, "w") as stream:
yaml.dump(template_dict, stream)
self.sig_save_template.emit(
f"Current routine options saved to template: {os.path.basename(template_path)}"
)
except (FileNotFoundError, yaml.YAMLError) as e:
print(f"Error saving template: {e}")
return

def refresh_ui(self, routine: Routine = None, silent: bool = False):
self.routine = routine # save routine for future reference

Expand Down
3 changes: 3 additions & 0 deletions src/badger/gui/acr/pages/home_page.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,9 @@ def config_logic(self):

self.history_browser.tree_widget.itemSelectionChanged.connect(self.go_run)

self.routine_editor.sig_load_template.connect(self.update_status)
self.routine_editor.sig_save_template.connect(self.update_status)

self.run_monitor.sig_inspect.connect(self.inspect_solution)
self.run_monitor.sig_lock.connect(self.toggle_lock)
self.run_monitor.sig_new_run.connect(self.new_run)
Expand Down

0 comments on commit d7f8dd7

Please sign in to comment.