Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Error msg and routine renaming (with bugfix) #90

Merged
merged 10 commits into from
Sep 11, 2024
34 changes: 30 additions & 4 deletions src/badger/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,19 +113,32 @@ def save_routine(routine: Routine):

# This function is not safe and might break database! Use with caution!
@maybe_create_routines_db
def update_routine(routine: Routine):
@maybe_create_runs_db
def update_routine(routine: Routine, old_name=''):
db_routine = os.path.join(BADGER_DB_ROOT, 'routines.db')

con = sqlite3.connect(db_routine)
cur = con.cursor()

name = old_name if old_name else routine.name
cur.execute('select * from routine where name=:name',
{'name': routine.name})
{'name': name})
record = cur.fetchone()

if record: # update the record
cur.execute('update routine set config = ?, savedAt = ? where name = ?',
(routine.yaml(), datetime.now(), routine.name))
cur.execute('update routine set name = ?, config = ?, savedAt = ? where name = ?',
(routine.name, routine.yaml(), datetime.now(), name))

if old_name:
db_run = os.path.join(BADGER_DB_ROOT, 'runs.db')

con_run = sqlite3.connect(db_run, timeout=30.0)
cur_run = con_run.cursor()

cur_run.execute('update run set routine = ? where routine = ?',(routine.name, old_name))

con_run.commit()
con_run.close()

con.commit()
con.close()
Expand Down Expand Up @@ -300,6 +313,19 @@ def remove_run_by_id(rid):
con.commit()
con.close()

@maybe_create_runs_db
def get_routine_name_by_filename(filename):
db_run = os.path.join(BADGER_DB_ROOT, 'runs.db')

con = sqlite3.connect(db_run)
cur = con.cursor()

cur.execute(f'select routine from run where filename = "{filename}"')
routine_name = cur.fetchone()[0]
con.close()

return routine_name


def import_routines(filename):
con = sqlite3.connect(filename)
Expand Down
2 changes: 2 additions & 0 deletions src/badger/gui/default/components/routine_editor.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ def cancel_create_routine(self):
self.sig_canceled.emit()

def save_routine(self):
# here save() is not a property/attribute
# it's a method that also calls _compose_routine()
if self.routine_page.save() == 0:
self.sig_saved.emit()

Expand Down
4 changes: 4 additions & 0 deletions src/badger/gui/default/components/routine_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,3 +197,7 @@ def update_tooltip(self):
def update_description(self, descr):
self.description = descr
self.update_tooltip()

def update_name(self, name):
self.name = name
self.update_tooltip()
38 changes: 30 additions & 8 deletions src/badger/gui/default/components/routine_page.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from functools import partial
import os
import yaml
import json

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -36,6 +37,7 @@
from ..windows.var_dialog import BadgerVariableDialog
from ..windows.add_random_dialog import BadgerAddRandomDialog
from ..windows.message_dialog import BadgerScrollableMessageBox
from ..windows.expandable_message_box import ExpandableMessageBox
from ..utils import filter_generator_config
from ....db import save_routine, remove_routine, update_routine
from ....environment import instantiate_env
Expand All @@ -53,7 +55,8 @@


class BadgerRoutinePage(QWidget):
sig_updated = pyqtSignal(str, str) # routine name, routine description
name_updated = pyqtSignal(str, str) # routine name, routine new name
descr_updated = pyqtSignal(str, str) # routine name, routine description

def __init__(self):
super().__init__()
Expand Down Expand Up @@ -967,7 +970,7 @@ def update_description(self):
try:
update_routine(routine)
# Notify routine list to update
self.sig_updated.emit(routine.name, routine.description)
self.descr_updated.emit(routine.name, routine.description)
QMessageBox.information(
self,
'Update success!',
Expand All @@ -983,15 +986,34 @@ def update_description(self):
def save(self):
try:
routine = self._compose_routine()
except ValidationError:
return QMessageBox.critical(
self,
'Error!',
traceback.format_exc()
except ValidationError as e:
error_message = "".join([error['msg']+'\n\n' for error in e.errors()]).strip()
details = traceback.format_exc()
dialog = ExpandableMessageBox(
title="Error!",
text=error_message,
detailedText=details,
parent=self
)
dialog.setIcon(QMessageBox.Critical)
dialog.exec_()
return

try:
save_routine(routine)
if self.routine and routine != self.routine:
old_dict = json.loads(self.routine.json())
old_dict.pop('data')
new_dict = json.loads(routine.json())
new_dict.pop('data')
new_dict['name'] = old_dict['name']
new_dict['description'] = new_dict['description']
if old_dict == new_dict:
update_routine(routine, old_name=self.routine.name)
self.name_updated.emit(self.routine.name, routine.name)
else:
save_routine(routine)
else:
save_routine(routine)
except sqlite3.IntegrityError:
return QMessageBox.critical(
self, 'Error!',
Expand Down
22 changes: 20 additions & 2 deletions src/badger/gui/default/pages/home_page.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
list_routine,
load_routine,
remove_routine,
get_routine_name_by_filename
)
from ....settings import read_value
from ....utils import get_header, strtobool
Expand Down Expand Up @@ -275,9 +276,10 @@ def config_logic(self):
self.routine_editor.sig_saved.connect(self.routine_saved)
self.routine_editor.sig_canceled.connect(self.done_create_routine)
self.routine_editor.sig_deleted.connect(self.routine_deleted)
self.routine_editor.routine_page.sig_updated.connect(
self.routine_editor.routine_page.descr_updated.connect(
self.routine_description_updated
)
self.routine_editor.routine_page.name_updated.connect(self.routine_name_updated)

# Assign shortcuts
self.shortcut_go_search = QShortcut(QKeySequence("Ctrl+L"), self)
Expand Down Expand Up @@ -347,7 +349,11 @@ def build_routine_list(
self, routine_names: List[str], timestamps: List[str], environments: List[str], descriptions: List[str]
):
try:
selected_routine = self.prev_routine_item.routine_name
if self.prev_routine_item.routine_name in routine_names:
selected_routine = self.prev_routine_item.routine_name
else:
self.prev_routine_item = None
selected_routine = None
except Exception:
selected_routine = None
self.routine_list.clear()
Expand Down Expand Up @@ -419,6 +425,9 @@ def go_run(self, i: int):
run_filename = get_base_run_filename(self.cb_history.currentText())
try:
_routine = load_run(run_filename)
_routine.name = get_routine_name_by_filename(run_filename) # in case name changed
# if self.current_routine:
# _routine.name = self.current_routine.name
routine, _ = load_routine(_routine.name) # get the initial routine
# TODO: figure out how to recover the original routine
if routine is None: # routine not found, could be deleted
Expand Down Expand Up @@ -596,6 +605,15 @@ def routine_description_updated(self, name, descr):
if routine_item.name == name:
routine_item.update_description(descr)
break

def routine_name_updated(self, old_name, new_name):
for i in range(self.routine_list.count()):
item = self.routine_list.item(i)
if item is not None:
routine_item = self.routine_list.itemWidget(item)
if routine_item.name == old_name:
routine_item.update_name(new_name)
break

def export_routines(self):
options = QFileDialog.Options()
Expand Down
85 changes: 85 additions & 0 deletions src/badger/gui/default/windows/expandable_message_box.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import sys
from PyQt5.QtWidgets import (QDialog, QMessageBox, QVBoxLayout, QHBoxLayout,
QLabel, QPushButton, QTextEdit)
from PyQt5.QtGui import QTextOption, QFont, QFontDatabase

class ExpandableMessageBox(QDialog):
def __init__(self, icon=None, title="Message", text="", detailedText="", parent=None):
super().__init__(parent)

# Main layout
mainLayout = QVBoxLayout(self)

# Top layout for icon and main text
topLayout = QHBoxLayout()
self.iconLabel = QLabel()
if icon:
self.iconLabel.setPixmap(icon.pixmap(64, 64))
self.textLabel = QLabel(text)
self.textLabel.setMinimumWidth(280)
self.textLabel.setWordWrap(True)
font = QFont()
font.setBold(True)
self.textLabel.setFont(font)
topLayout.addWidget(self.iconLabel)
topLayout.addWidget(self.textLabel, 1)
mainLayout.addLayout(topLayout)

# Detailed text area setup
self.detailedTextWidget = QTextEdit()
self.detailedTextWidget.setText(detailedText)
self.detailedTextWidget.setReadOnly(True)
self.detailedTextWidget.setWordWrapMode(QTextOption.WrapAtWordBoundaryOrAnywhere)
monoFont = QFontDatabase.systemFont(QFontDatabase.FixedFont)
monoFont.setPointSize(9)
self.detailedTextWidget.setFont(monoFont)
self.detailedTextWidget.setVisible(False) # Initially hidden

# Button to show/hide details
self.toggleButton = QPushButton("Show Details")
self.toggleButton.clicked.connect(self.toggle_details)

# Layouts for the detailed text and button
self.detailLayout = QVBoxLayout()
self.detailLayout.addWidget(self.detailedTextWidget)
self.detailLayout.addWidget(self.toggleButton)

# Add the detailed text area and toggle button to the main layout
mainLayout.addLayout(self.detailLayout)

# Buttons
self.buttonBox = QHBoxLayout()
self.okButton = QPushButton("OK")
self.okButton.clicked.connect(self.accept)
self.buttonBox.addWidget(self.okButton)
mainLayout.addLayout(self.buttonBox)

# Set window properties
self.setWindowTitle(title)
self.resize(420, 250)

def toggle_details(self):
if self.detailedTextWidget.isVisible():
self.detailedTextWidget.setVisible(False)
self.toggleButton.setText("Show Details")
else:
self.detailedTextWidget.setVisible(True)
self.toggleButton.setText("Hide Details")

def setText(self, text):
self.textLabel.setText(text)

def setDetailedText(self, detailedText):
self.detailedTextWidget.setText(detailedText)

def setIcon(self, icon):
# This maps the QMessageBox icons to the QDialog
iconMap = {
QMessageBox.Information: QMessageBox.standardIcon(QMessageBox.Information),
QMessageBox.Warning: QMessageBox.standardIcon(QMessageBox.Warning),
QMessageBox.Critical: QMessageBox.standardIcon(QMessageBox.Critical),
QMessageBox.Question: QMessageBox.standardIcon(QMessageBox.Question),
}
standardIcon = iconMap.get(icon)
if standardIcon:
self.iconLabel.setPixmap(standardIcon)
24 changes: 23 additions & 1 deletion src/badger/routine.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that we got several duplicated lines possibly due to conflict merging, not a big deal though. I'll just merge and fix it :)

Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from copy import deepcopy
from typing import Any, List, Optional

import numpy as np
import numpy as np
import pandas as pd
from pandas import DataFrame
Expand All @@ -17,6 +18,8 @@
from xopt.generators import get_generator
from xopt.utils import get_local_region
from badger.utils import curr_ts
from xopt.utils import get_local_region
from badger.utils import curr_ts
from badger.environment import Environment, instantiate_env
from badger.utils import curr_ts

Expand All @@ -34,6 +37,11 @@ class Routine(Xopt):
vrange_limit_options: Optional[dict] = Field(None)
initial_point_actions: Optional[List] = Field(None)
additional_variables: Optional[List[str]] = Field([])
# Store relative to current params
relative_to_current: Optional[bool] = Field(False)
vrange_limit_options: Optional[dict] = Field(None)
initial_point_actions: Optional[List] = Field(None)
additional_variables: Optional[List[str]] = Field([])

model_config = ConfigDict(arbitrary_types_allowed=True)

Expand Down Expand Up @@ -156,7 +164,21 @@ def json(self, **kwargs) -> str:
except AttributeError:
pass

return json.dumps(dict_result)
return json.dumps(dict_result)

def __eq__(self, routine):
if not isinstance(routine, Routine):
return False
self_dict = json.loads(self.json())
self_dict.pop('data')
routine_dict = json.loads(routine.json())
routine_dict.pop('data')
return self_dict == routine_dict

def __hash__(self):
self_dict = json.loads(self.json())
self_dict.pop('data')
return hash(tuple(sorted(self_dict)))


def calculate_variable_bounds(limit_options, vocs, env):
Expand Down
Loading