Skip to content

Commit

Permalink
Support constraints and states
Browse files Browse the repository at this point in the history
  • Loading branch information
zhe-slac committed Mar 15, 2024
1 parent 66be893 commit fdaf3c0
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 68 deletions.
22 changes: 11 additions & 11 deletions src/badger/gui/default/components/routine_page.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,9 +222,9 @@ def refresh_ui(self, routine: Routine = None):
'EQUAL_TO'].index(relation)
self.add_constraint(name, relation, thres, critical)

constants = routine.vocs.constants
if len(constants):
for name_sta, val in constants.items():
observables = routine.vocs.observable_names
if len(observables):
for name_sta in observables:
self.add_state(name_sta)

# Config the metadata
Expand Down Expand Up @@ -547,19 +547,19 @@ def _compose_vocs(self) -> (VOCS, List[str]):
if critical:
critical_constraints.append(con_name)

states = {}
states = []
for i in range(self.env_box.list_sta.count()):
raise NotImplementedError("constants/states has not been implemented yet!")
#item = self.env_box.list_sta.item(i)
#item_widget = self.env_box.list_sta.itemWidget(item)
#sta_name = item_widget.cb_sta.currentText()
#states[sta_name] =
item = self.env_box.list_sta.item(i)
item_widget = self.env_box.list_sta.itemWidget(item)
sta_name = item_widget.cb_sta.currentText()
states.append(sta_name)

vocs = VOCS(
variables=variables,
objectives=objectives,
constraints={},
constants={}
constraints=constraints,
constants={},
observables=states,
)

return vocs, critical_constraints
Expand Down
116 changes: 61 additions & 55 deletions src/badger/gui/default/components/run_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,10 @@ def __init__(self):

@property
def vocs(self) -> VOCS:
return self.routine.vocs
if self.routine is None:
return None
else:
return self.routine.vocs

def init_ui(self):
# Load all icons
Expand Down Expand Up @@ -404,18 +407,20 @@ def init_plots(self, routine: Routine = None, run_filename: str = None):
self.plot_var.clear()
self.plot_obj.clear()

# if constraints are active clear them
# if constraints are active delete them
try:
self.plot_con.clear()
self.plot_con.addItem(self.inspector_constraint)
except AttributeError:
self.monitor.removeItem(self.plot_con)
self.plot_con.removeItem(self.inspector_constraint)
del self.plot_con
except:
pass

# if statics exist clear that plot
# if statics exist delete that plot
try:
self.plot_sta.clear()
self.plot_sta.addItem(self.inspector_state)
except AttributeError:
self.monitor.removeItem(self.plot_sta)
self.plot_sta.removeItem(self.inspector_state)
del self.plot_sta
except:
pass

# if no routine is loaded set button to disabled
Expand All @@ -427,6 +432,8 @@ def init_plots(self, routine: Routine = None, run_filename: str = None):
self.btn_opt.setDisabled(True)
self.btn_set.setDisabled(True)

self.routine = None

return

self.routine = routine
Expand All @@ -435,7 +442,7 @@ def init_plots(self, routine: Routine = None, run_filename: str = None):
objective_names = self.vocs.objective_names
variable_names = self.vocs.variable_names
constraint_names = self.vocs.constraint_names
sta_names = self.vocs.constant_names
sta_names = self.vocs.observable_names

# Configure variable plots
self.curves_variable = self._configure_plot(
Expand All @@ -453,12 +460,11 @@ def init_plots(self, routine: Routine = None, run_filename: str = None):
self.plot_con
except:
self.plot_con = plot_con = add_axes(
self.monitor, "Constraints", 'Evaluation History (C)',
self.monitor, "constraints", 'Evaluation History (C)',
self.inspector_constraint, row=1, col=0
)
plot_con.setXLink(self.plot_obj)

# Configure objective plots
self.curves_constraint = self._configure_plot(
self.plot_con, self.inspector_constraint, constraint_names
)
Expand All @@ -477,18 +483,14 @@ def init_plots(self, routine: Routine = None, run_filename: str = None):
self.plot_sta
except:
self.plot_sta = plot_sta = add_axes(
self.monitor, "Constants", 'Evaluation History (S)',
self.inspector_state, row=1, col=0
self.monitor, "states", 'Evaluation History (S)',
self.inspector_state, row=2, col=0
)
plot_sta.setXLink(self.plot_obj)

self.curves_sta = []
for i, sta_name in enumerate(sta_names):
color = self.colors[i % len(self.colors)]
symbol = self.symbols[i % len(self.colors)]
_curve = self.plot_sta.plot(pen=pg.mkPen(color, width=3),
name=sta_name)
self.curves_sta.append(_curve)
self.curves_sta = self._configure_plot(
self.plot_sta, self.inspector_state, sta_names
)
else:
try:
self.monitor.removeItem(self.plot_sta)
Expand Down Expand Up @@ -595,8 +597,8 @@ def enable_auto_range(self):
if self.vocs.constraint_names:
self.plot_con.enableAutoRange()

# if self.sta_names:
# self.plot_sta.enableAutoRange()
if self.vocs.observable_names:
self.plot_sta.enableAutoRange()

def open_extensions_palette(self):
self.extensions_palette.show()
Expand Down Expand Up @@ -657,8 +659,7 @@ def update_curves(self):
set_data(variable_names, self.curves_variable, input_data, ts)
set_data(self.vocs.objective_names, self.curves_objective, data_copy, ts)
set_data(self.vocs.constraint_names, self.curves_constraint, data_copy, ts)

# TODO: add tracking of observables
set_data(self.vocs.observable_names, self.curves_sta, data_copy, ts)

def check_critical(self):
"""
Expand Down Expand Up @@ -690,10 +691,11 @@ def check_critical(self):
if len(self.routine.critical_constraint_names) == 0:
return
else:
feas = self.vocs.feasibility_data(self.routine.data.iloc[-1])
feas = self.vocs.feasibility_data(self.routine.data.iloc[-1], prefix='')
# print(feas[self.routine.critical_constraint_names], self.vocs)
violated_critical = ~feas[self.routine.critical_constraint_names].any()

if not violated_critical:
if not violated_critical.item():
return

# if code reaches this point there is a critical constraint violated
Expand Down Expand Up @@ -819,27 +821,27 @@ def ins_obj_dragged(self, ins_obj):
self.inspector_variable.setValue(ins_obj.value())
if self.vocs.constraint_names:
self.inspector_constraint.setValue(ins_obj.value())
# if self.sta_names:
# self.inspector_state.setValue(ins_obj.value())
if self.vocs.observable_names:
self.inspector_state.setValue(ins_obj.value())

def ins_con_dragged(self, ins_con):
self.inspector_variable.setValue(ins_con.value())
self.inspector_objective.setValue(ins_con.value())
# if self.sta_names:
# self.inspector_state.setValue(ins_con.value())
if self.vocs.observable_names:
self.inspector_state.setValue(ins_con.value())

def ins_sta_dragged(self, ins_sta):
self.inspector_variable.setValue(ins_sta.value())
self.inspector_objective.setValue(ins_sta.value())
# if self.vocs.constraint_names:
# self.inspector_constraint.setValue(ins_sta.value())
if self.vocs.constraint_names:
self.inspector_constraint.setValue(ins_sta.value())

def ins_var_dragged(self, ins_var):
self.inspector_objective.setValue(ins_var.value())
if self.vocs.constraint_names:
self.inspector_constraint.setValue(ins_var.value())
# if self.sta_names:
# self.inspector_state.setValue(ins_var.value())
if self.vocs.observable_names:
self.inspector_state.setValue(ins_var.value())

def ins_drag_done(self, ins):
self.sync_ins(ins.value())
Expand All @@ -848,13 +850,16 @@ def sync_ins(self, pos):
if self.plot_x_axis: # x-axis is time
value, idx = self.closest_ts(pos)
else:
ts = self.extract_timestamp()
value = idx = np.clip(np.round(pos), 0, len(ts) - 1)
try:
ts = self.extract_timestamp()
value = idx = np.clip(np.round(pos), 0, len(ts) - 1)
except: # no data
value = idx = np.round(pos)
self.inspector_objective.setValue(value)
if self.vocs.constraint_names:
if self.vocs and self.vocs.constraint_names:
self.inspector_constraint.setValue(value)
# if self.sta_names:
# self.inspector_state.setValue(value)
if self.vocs and self.vocs.observable_names:
self.inspector_state.setValue(value)
self.inspector_variable.setValue(value)

self.sig_inspect.emit(idx)
Expand Down Expand Up @@ -894,6 +899,7 @@ def jump_to_optimal(self):
try:
best_idx, _ = self.routine.vocs.select_best(
self.routine.sorted_data, n=1)
# print(best_idx, _)
best_idx = int(best_idx[0])

self.jump_to_solution(best_idx)
Expand All @@ -914,8 +920,8 @@ def jump_to_solution(self, idx):
self.inspector_objective.setValue(value)
if self.vocs.constraint_names:
self.inspector_constraint.setValue(value)
# if self.sta_names:
# self.inspector_state.setValue(value)
if self.vocs.observable_names:
self.inspector_state.setValue(value)
self.inspector_variable.setValue(value)

def set_vars(self):
Expand Down Expand Up @@ -952,15 +958,15 @@ def select_x_axis(self, i):
self.plot_obj.setLabel('bottom', 'time (s)')
if self.vocs.constraint_names:
self.plot_con.setLabel('bottom', 'time (s)')
# if self.sta_names:
# self.plot_sta.setLabel('bottom', 'time (s)')
if self.vocs.observable_names:
self.plot_sta.setLabel('bottom', 'time (s)')
else:
self.plot_var.setLabel('bottom', 'iterations')
self.plot_obj.setLabel('bottom', 'iterations')
if self.vocs.constraint_names:
self.plot_con.setLabel('bottom', 'iterations')
# if self.sta_names:
# self.plot_sta.setLabel('bottom', 'iterations')
if self.vocs.observable_names:
self.plot_sta.setLabel('bottom', 'iterations')

# Update inspector line position
if i:
Expand All @@ -971,8 +977,8 @@ def select_x_axis(self, i):
self.inspector_objective.setValue(value)
if self.vocs.constraint_names:
self.inspector_constraint.setValue(value)
# if self.sta_names:
# self.inspector_state.setValue(value)
if self.vocs.observable_names:
self.inspector_state.setValue(value)
self.inspector_variable.setValue(value)

self.update_curves()
Expand All @@ -989,18 +995,18 @@ def toggle_x_plot_y_axis_relative(self):
def on_mouse_click(self, event):
# https://stackoverflow.com/a/64081483
coor_obj = self.plot_obj.vb.mapSceneToView(event._scenePos)
if self.vocs.constraint_names:
if self.vocs and self.vocs.constraint_names:
coor_con = self.plot_con.vb.mapSceneToView(event._scenePos)
# if self.sta_names:
# coor_sta = self.plot_sta.vb.mapSceneToView(event._scenePos)
if self.vocs and self.vocs.observable_names:
coor_sta = self.plot_sta.vb.mapSceneToView(event._scenePos)
coor_var = self.plot_var.vb.mapSceneToView(event._scenePos)

flag = self.plot_obj.viewRect().contains(coor_obj) or \
self.plot_var.viewRect().contains(coor_var)
if self.vocs.constraint_names:
self.plot_var.viewRect().contains(coor_var)
if self.vocs and self.vocs.constraint_names:
flag = flag or self.plot_con.viewRect().contains(coor_con)
# if self.sta_names:
# flag = flag or self.plot_sta.viewRect().contains(coor_sta)
if self.vocs and self.vocs.observable_names:
flag = flag or self.plot_sta.viewRect().contains(coor_sta)

if flag:
self.sync_ins(coor_obj.x())
Expand Down
5 changes: 3 additions & 2 deletions src/badger/routine.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,9 @@ def validate_data(cls, v, info: ValidationInfo):
@property
def sorted_data(self):
data_copy = deepcopy(self.data)
data_copy.index = data_copy.index.astype(int)
data_copy.sort_index(inplace=True)
if data_copy is not None:
data_copy.index = data_copy.index.astype(int)
data_copy.sort_index(inplace=True)

return data_copy

Expand Down

0 comments on commit fdaf3c0

Please sign in to comment.