Skip to content

Commit

Permalink
Run selection works, 2 compatible doesnt
Browse files Browse the repository at this point in the history
  • Loading branch information
KrissiHub committed Jan 8, 2024
1 parent 857c17b commit 3a21fae
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 45 deletions.
22 changes: 6 additions & 16 deletions deepcave/plugins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,6 @@ def plugin_input_update(pathname: str, *inputs_list: str) -> List[str]:
elif inputs is not None:
# We have to update the options of the run selection here.
# This is important if the user has added/removed runs.
print("hallo")
if self.activate_run_selection:
run_value = inputs["run"]["value"]
new_inputs = self.__class__.load_run_inputs(
Expand Down Expand Up @@ -506,6 +505,7 @@ def _process_raw_outputs(
if mpl_active:
outputs = self.__class__.load_mpl_outputs(passed_runs, cleaned_inputs, passed_outputs)
else:
("Hallo")
outputs = self.__class__.load_outputs(passed_runs, cleaned_inputs, passed_outputs)

logger.debug("Raw outputs processed successfully.")
Expand Down Expand Up @@ -803,21 +803,11 @@ def __call__(self, render_button: bool = False) -> List[Component]:
else:
components += [html.H1(self.name)]

#If the runs for cost over time are not compatible
#it should still be possible to look at them separatly
if self.id == "cost_over_time":
try:
self.check_runs_compatibility(self.all_runs)
except NotMergeableError as message:
notification.update("The runs you chose could not be combined. You can still choose to look at the Cost Over Time for one specific run though.")
self.activate_run_selection = True
return components
else:
try:
self.check_runs_compatibility(self.all_runs)
except NotMergeableError as message:
notification.update(str(message))
return components
try:
self.check_runs_compatibility(self.all_runs)
except NotMergeableError as message:
notification.update(str(message))
return components

if self.activate_run_selection:
run_input_layout = [self.__class__.get_run_input_layout(self.register_input)]
Expand Down
127 changes: 99 additions & 28 deletions deepcave/plugins/objective/cost_over_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,27 @@
get_hovertext_from_config,
save_image,
)
from deepcave.runs.group import NotMergeableError
from deepcave import notification


class CostOverTime(DynamicPlugin):
id = "cost_over_time"
name = "Cost Over Time"
icon = "fas fa-chart-line"
activate_run_selection = False
activate_run_selection = True
help = "docs/plugins/cost_over_time.rst"

def check_runs_compatibility(self, runs: List[AbstractRun]) -> None:
check_equality(runs, objectives=True, budgets=True)
#If the runs are not mergeable, there should still
#be an option to look at one of the runs
try:
check_equality(runs, objectives=True, budgets=True)
except NotMergeableError:
notification.update("The runs you chose could not be combined. You can still choose to look at the Cost Over Time for one specific run though.")

# Set some attributes here
run = runs[0]

objective_names = run.get_objective_names()
objective_ids = run.get_objective_ids()
self.objective_options = get_select_options(objective_names, objective_ids)
Expand Down Expand Up @@ -163,21 +169,17 @@ def load_outputs(runs, inputs, outputs):
return go.Figure()

traces = []
for idx, run in enumerate(runs):
if run.prefix == "group" and not show_groups:
continue

if run.prefix != "group" and not show_runs:
continue
if isinstance(runs, AbstractRun):
run = runs

objective = run.get_objective(inputs["objective_id"])
config_ids = outputs[run.id]["config_ids"]
x = outputs[run.id]["times"]
config_ids = outputs["config_ids"]
x = outputs["times"]
if inputs["xaxis"] == "trials":
x = outputs[run.id]["ids"]

y = np.array(outputs[run.id]["costs_mean"])
y_err = np.array(outputs[run.id]["costs_std"])
x = outputs["ids"]
y = np.array(outputs["costs_mean"])
y_err = np.array(outputs["costs_std"])
y_upper = list(y + y_err)
y_lower = list(y - y_err)
y = list(y)
Expand All @@ -198,7 +200,7 @@ def load_outputs(runs, inputs, outputs):
y=y,
name=run.name,
line_shape="hv",
line=dict(color=get_color(idx)),
line=dict(color=get_color(0)),
hovertext=hovertext,
hoverinfo=hoverinfo,
marker=dict(symbol=symbol),
Expand All @@ -207,30 +209,99 @@ def load_outputs(runs, inputs, outputs):
)

traces.append(
go.Scatter(
x=x,
y=y_upper,
line=dict(color=get_color(idx, 0)),
line_shape="hv",
hoverinfo="skip",
showlegend=False,
marker=dict(symbol=None),
go.Scatter(
x=x,
y=y_upper,
line=dict(color=get_color(0, 0)),
line_shape="hv",
hoverinfo="skip",
showlegend=False,
marker=dict(symbol=None),
)
)
)


traces.append(
go.Scatter(
x=x,
y=y_lower,
fill="tonexty",
fillcolor=get_color(idx, 0.2),
line=dict(color=get_color(idx, 0)),
fillcolor=get_color(0, 0.2),
line=dict(color=get_color(0, 0)),
line_shape="hv",
hoverinfo="skip",
showlegend=False,
marker=dict(symbol=None),
)
)
else:
for idx, run in enumerate(runs):
if run.prefix == "group" and not show_groups:
continue

if run.prefix != "group" and not show_runs:
continue

objective = run.get_objective(inputs["objective_id"])
config_ids = outputs[run.id]["config_ids"]
x = outputs[run.id]["times"]
if inputs["xaxis"] == "trials":
x = outputs[run.id]["ids"]

y = np.array(outputs[run.id]["costs_mean"])
y_err = np.array(outputs[run.id]["costs_std"])
y_upper = list(y + y_err)
y_lower = list(y - y_err)
y = list(y)

hovertext = ""
hoverinfo = "skip"
symbol = None
mode = "lines"
if len(config_ids) > 0:
hovertext = [get_hovertext_from_config(run, config_id) for config_id in config_ids]
hoverinfo = "text"
symbol = "circle"
mode = "lines+markers"

traces.append(
go.Scatter(
x=x,
y=y,
name=run.name,
line_shape="hv",
line=dict(color=get_color(idx)),
hovertext=hovertext,
hoverinfo=hoverinfo,
marker=dict(symbol=symbol),
mode=mode,
)
)

traces.append(
go.Scatter(
x=x,
y=y_upper,
line=dict(color=get_color(idx, 0)),
line_shape="hv",
hoverinfo="skip",
showlegend=False,
marker=dict(symbol=None),
)
)

traces.append(
go.Scatter(
x=x,
y=y_lower,
fill="tonexty",
fillcolor=get_color(idx, 0.2),
line=dict(color=get_color(idx, 0)),
line_shape="hv",
hoverinfo="skip",
showlegend=False,
marker=dict(symbol=None),
)
)

if objective is None:
raise PreventUpdate
Expand Down
1 change: 0 additions & 1 deletion deepcave/runs/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,6 @@ def get_run(self, run_id: str) -> AbstractRun:
for run in runs:
if run.id == run_id:
return run

raise RuntimeError("Run not found.")

def get_groups(self) -> List[Group]:
Expand Down

0 comments on commit 3a21fae

Please sign in to comment.