-
Notifications
You must be signed in to change notification settings - Fork 13
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding Save/Load Template to Routine Page #145
Changes from 25 commits
ad13d6e
b754df8
d3493d2
b03d7be
6a618c2
cffe306
295f841
9d47778
ca9fff1
01b80d4
c16f09a
39cb30e
e4276a1
3923e29
ac65fc4
06faa97
9650ac8
c23b544
8de3171
05dff3b
af24733
5d8fa0b
0acd1d4
8e89031
d608fe4
228c720
2933c3e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,7 +8,7 @@ | |
import numpy as np | ||
import pandas as pd | ||
from PyQt5.QtCore import Qt, pyqtSignal | ||
from PyQt5.QtWidgets import QLineEdit, QLabel, QPushButton | ||
from PyQt5.QtWidgets import QLineEdit, QLabel, QPushButton, QFileDialog | ||
from PyQt5.QtWidgets import QListWidgetItem, QMessageBox, QWidget, QTabWidget | ||
from PyQt5.QtWidgets import QVBoxLayout, QHBoxLayout, QScrollArea | ||
from PyQt5.QtWidgets import QTableWidgetItem, QPlainTextEdit | ||
|
@@ -59,6 +59,8 @@ | |
|
||
class BadgerRoutinePage(QWidget): | ||
sig_updated = pyqtSignal(str, str) # routine name, routine description | ||
sig_load_template = pyqtSignal(str) # template path | ||
sig_save_template = pyqtSignal(str) # template path | ||
|
||
def __init__(self): | ||
super().__init__() | ||
|
@@ -98,6 +100,10 @@ def __init__(self): | |
# Trigger the re-rendering of the environment box | ||
self.env_box.relative_to_curr.setChecked(True) | ||
|
||
# Template path | ||
# template_dir = os.path.join(self.BADGER_PLUGIN_ROOT, "templates") | ||
self.template_dir = "/home/physics/mlans/workspace/badger_test/Badger/src/badger/built_in_plugins/templates" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think what you have commented is the right idea. @wenatuhs Do we want to keep it like this, or do you want to add a templates directory to your copy of Badger on prod? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree that we should go the commented way ( |
||
|
||
def init_ui(self): | ||
config_singleton = init_settings() | ||
|
||
|
@@ -158,6 +164,18 @@ def init_ui(self): | |
vbox_meta.addWidget(descr) | ||
descr_bar.hide() | ||
|
||
# Save Template Button | ||
template_button = QWidget() | ||
hbox_name = QHBoxLayout(template_button) | ||
hbox_name.setContentsMargins(0, 0, 0, 0) | ||
self.save_template_button = save_template_button = QPushButton( | ||
"Save as Template" | ||
) | ||
save_template_button.setFixedWidth(128) | ||
hbox_name.addWidget(save_template_button, alignment=Qt.AlignRight) | ||
vbox_meta.addWidget(template_button, alignment=Qt.AlignBottom) | ||
template_button.show() | ||
|
||
# Tags | ||
self.cbox_tags = cbox_tags = BadgerFilterBox(title=" Tags") | ||
if not strtobool(config_singleton.read_value("BADGER_ENABLE_ADVANCED")): | ||
|
@@ -168,7 +186,9 @@ def init_ui(self): | |
# vbox.addWidget(group_meta) | ||
|
||
# Env box | ||
BADGER_PLUGIN_ROOT = config_singleton.read_value("BADGER_PLUGIN_ROOT") | ||
self.BADGER_PLUGIN_ROOT = BADGER_PLUGIN_ROOT = config_singleton.read_value( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Honestly, you might even just stick this in the constructor since this isn't really a UI thing. |
||
"BADGER_PLUGIN_ROOT" | ||
) | ||
env_dict_dir = os.path.join( | ||
BADGER_PLUGIN_ROOT, "environments", "env_colors.yaml" | ||
) | ||
|
@@ -180,7 +200,8 @@ def init_ui(self): | |
self.env_box = BadgerEnvBox(env_dict, None, self.envs) | ||
scroll_area = QScrollArea() | ||
scroll_area.setFrameShape(QScrollArea.NoFrame) | ||
scroll_area.setStyleSheet(""" | ||
scroll_area.setStyleSheet( | ||
""" | ||
QScrollArea { | ||
border: none; /* Remove border */ | ||
margin: 0px; /* Remove margin */ | ||
|
@@ -189,7 +210,8 @@ def init_ui(self): | |
QScrollArea > QWidget { | ||
margin: 0px; /* Remove margin inside */ | ||
} | ||
""") | ||
""" | ||
) | ||
scroll_content_env = QWidget() | ||
scroll_layout_env = QVBoxLayout(scroll_content_env) | ||
scroll_layout_env.setContentsMargins(0, 0, 15, 0) | ||
|
@@ -208,6 +230,8 @@ def init_ui(self): | |
|
||
def config_logic(self): | ||
self.btn_descr_update.clicked.connect(self.update_description) | ||
self.env_box.load_template_button.clicked.connect(self.load_template_yaml) | ||
self.save_template_button.clicked.connect(self.save_template_yaml) | ||
self.generator_box.cb.currentIndexChanged.connect(self.select_generator) | ||
self.generator_box.btn_docs.clicked.connect(self.open_generator_docs) | ||
self.generator_box.check_use_script.stateChanged.connect(self.toggle_use_script) | ||
|
@@ -232,6 +256,241 @@ def config_logic(self): | |
self.env_box.var_table.sig_sel_changed.connect(self.update_init_table) | ||
self.env_box.var_table.sig_pv_added.connect(self.handle_pv_added) | ||
|
||
def load_template_yaml(self) -> None: | ||
""" | ||
Load data from template .yaml into template_dict dictionary. | ||
This function expects to be called via an action from | ||
a QPushButton | ||
""" | ||
|
||
if isinstance(self.sender(), QPushButton): | ||
# load template from button | ||
options = QFileDialog.Options() | ||
template_path, _ = QFileDialog.getOpenFileName( | ||
self, | ||
"Select YAML File", | ||
self.template_dir, | ||
"YAML Files (*.yaml);;All Files (*)", | ||
options=options, | ||
) | ||
|
||
if not template_path: | ||
return | ||
|
||
# Load template file | ||
try: | ||
with open(template_path, "r") as stream: | ||
template_dict = yaml.safe_load(stream) | ||
self.set_options_from_template(template_dict=template_dict) | ||
self.sig_load_template.emit( | ||
f"Options loaded from template: {os.path.basename(template_path)}" | ||
) | ||
except (FileNotFoundError, yaml.YAMLError) as e: | ||
print(f"Error loading template: {e}") | ||
return | ||
|
||
def set_options_from_template(self, template_dict: dict): | ||
""" | ||
Fills in routine_page GUI with relevant info from template_dict | ||
dictionary | ||
""" | ||
|
||
# Compose the template | ||
try: | ||
name = template_dict["name"] | ||
description = template_dict["description"] | ||
relative_to_current = template_dict["relative_to_current"] | ||
generator_name = template_dict["generator"]["name"] | ||
env_name = template_dict["environment"]["name"] | ||
vrange_limit_options = template_dict["vrange_limit_options"] | ||
initial_point_actions = template_dict[ | ||
"initial_point_actions" | ||
] # should be type: add_curr | ||
critical_constraint_names = template_dict["critical_constraint_names"] | ||
env_params = template_dict["environment"]["params"] | ||
except KeyError as e: | ||
QMessageBox.warning(self, "Error", f"Missing key in template: {e}") | ||
return | ||
|
||
# set vocs | ||
vocs = VOCS( | ||
variables=template_dict["vocs"]["variables"], | ||
objectives=template_dict["vocs"]["objectives"], | ||
constraints=template_dict["vocs"]["constraints"], | ||
constants={}, | ||
observables=template_dict["vocs"]["observables"], | ||
) | ||
|
||
# set description | ||
self.edit_descr.setPlainText(description) | ||
|
||
# set generator | ||
if generator_name in self.generators: | ||
i = self.generators.index(generator_name) | ||
self.generator_box.cb.setCurrentIndex(i) | ||
|
||
filtered_config = filter_generator_config( | ||
generator_name, template_dict["generator"] | ||
) | ||
self.generator_box.edit.setPlainText(get_yaml_string(filtered_config)) | ||
|
||
# set environment | ||
if env_name in self.envs: | ||
i = self.envs.index(env_name) | ||
self.env_box.cb.setCurrentIndex(i) | ||
self.env_box.edit.setPlainText(get_yaml_string(env_params)) | ||
|
||
# set init points based on relative to current | ||
if relative_to_current: | ||
# make sure gui checkbox state matches yaml option | ||
if not self.env_box.relative_to_curr.isChecked(): | ||
self.env_box.relative_to_curr.setChecked(True) | ||
|
||
else: | ||
if self.env_box.relative_to_curr.isChecked(): | ||
self.env_box.relative_to_curr.setChecked(False) | ||
|
||
self.ratio_var_ranges = vrange_limit_options | ||
self.init_table_actions = initial_point_actions | ||
|
||
self.env_box.init_table.clear() | ||
|
||
# set bounds (should this be somewhere else?) | ||
if env_name: | ||
bounds = self.calc_auto_bounds() | ||
self.env_box.var_table.set_bounds(bounds) | ||
|
||
# set selected variables | ||
self.env_box.var_table.set_selected(vocs.variables) | ||
# self.env_box.var_table.set_bounds(vocs.variables) | ||
self.env_box.check_only_var.setChecked(True) | ||
|
||
if not relative_to_current: | ||
# set initial points to sample | ||
self._fill_init_table() | ||
|
||
# set objectives | ||
self.env_box.obj_table.set_selected(vocs.objectives) | ||
self.env_box.obj_table.set_rules(vocs.objectives) | ||
self.env_box.check_only_obj.setChecked(True) | ||
|
||
# set constraints | ||
constraints = vocs.constraints | ||
if len(constraints): | ||
for name, val in constraints.items(): | ||
relation, thres = val | ||
critical = name in critical_constraint_names | ||
relation = ["GREATER_THAN", "LESS_THAN", "EQUAL_TO"].index(relation) | ||
self.add_constraint(name, relation, thres, critical) | ||
else: | ||
self.env_box.list_con.clear() | ||
|
||
# set observables | ||
observables = vocs.observable_names | ||
if len(observables): | ||
for name_sta in observables: | ||
self.add_state(name_sta) | ||
else: | ||
self.env_box.list_obs.clear() | ||
|
||
def generate_template_dict_from_gui(self): | ||
""" | ||
Generate a template dictionary from the current state of the GUI | ||
""" | ||
|
||
vocs, critical_constraints = self._compose_vocs() | ||
|
||
vrange_limit_options = {} | ||
|
||
for var in self.env_box.var_table.variables: | ||
# set bounds to variable range limits | ||
name = next(iter(var)) | ||
if self.env_box.var_table.is_checked(name): | ||
bounds = var[name] | ||
if name in vocs.variables: | ||
vocs.variables[name] = bounds | ||
|
||
# Record the ratio var ranges | ||
if self.env_box.relative_to_curr.isChecked(): | ||
# set all to self.limit_option (I don't think auto mode *currently allows | ||
# setting different ranges for different vars) | ||
for vname in vocs.variables: | ||
vrange_limit_options[vname] = copy.deepcopy(self.limit_option) | ||
else: | ||
# Set vrange_limit_options based on current table info | ||
# Set each based on bounds in table --> convert to percentage of full range | ||
var_bounds = self.env_box.var_table.export_variables() | ||
for var_name in var_bounds: | ||
# get bounds from table | ||
vocs_bounds = vocs.variables[var_name] | ||
bound_range = vocs_bounds[1] - vocs_bounds[0] | ||
desired_bound_range = var_bounds[var_name][1] - var_bounds[var_name][0] | ||
|
||
# calc percentage of full range | ||
ratio_full = desired_bound_range / bound_range | ||
|
||
# calc percentage of current value | ||
# Probably a better way to get current value? | ||
var_curr = var_bounds[var_name][0] + 0.5 * desired_bound_range | ||
ratio_curr = float(desired_bound_range / np.abs(var_curr)) | ||
|
||
vrange_limit_options[var_name] = { | ||
"limit_option_idx": 1, | ||
"ratio_curr": ratio_curr, | ||
"ratio_full": ratio_full, | ||
} | ||
|
||
template_dict = { | ||
"name": self.edit_save.text(), | ||
"description": str(self.edit_descr.toPlainText()), | ||
"relative_to_current": self.env_box.relative_to_curr.isChecked(), | ||
"generator": { | ||
"name": self.generator_box.cb.currentText(), | ||
} | ||
| load_config(self.generator_box.edit.toPlainText()), | ||
"environment": { | ||
"name": self.env_box.cb.currentText(), | ||
"params": load_config(self.env_box.edit.toPlainText()), | ||
}, | ||
"vrange_limit_options": vrange_limit_options, | ||
"initial_point_actions": self.init_table_actions, | ||
"critical_constraint_names": critical_constraints, | ||
"vocs": vars(vocs), | ||
"badger_version": get_badger_version(), | ||
"xopt_version": get_xopt_version(), | ||
} | ||
|
||
return template_dict | ||
|
||
def save_template_yaml(self): | ||
""" | ||
Save the current routine as a template .yaml file | ||
""" | ||
|
||
template_dict = self.generate_template_dict_from_gui() | ||
|
||
options = QFileDialog.Options() | ||
template_path, _ = QFileDialog.getSaveFileName( | ||
self, | ||
"Save Template", | ||
self.template_dir, | ||
"YAML Files (*.yaml);;All Files (*)", | ||
options=options, | ||
) | ||
|
||
if not template_path: | ||
return | ||
|
||
try: | ||
with open(template_path, "w") as stream: | ||
yaml.dump(template_dict, stream) | ||
self.sig_save_template.emit( | ||
f"Current routine options saved to template: {os.path.basename(template_path)}" | ||
) | ||
except (FileNotFoundError, yaml.YAMLError) as e: | ||
print(f"Error saving template: {e}") | ||
return | ||
|
||
def refresh_ui(self, routine: Routine = None, silent: bool = False): | ||
self.routine = routine # save routine for future reference | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm guessing the button definitions are in different places because the actual buttons themselves are as well? Maybe you can show us tomorrow.