diff --git a/crystal_toolkit/helpers/utils.py b/crystal_toolkit/helpers/utils.py index 4a8f75df..f19fb4d5 100644 --- a/crystal_toolkit/helpers/utils.py +++ b/crystal_toolkit/helpers/utils.py @@ -4,7 +4,7 @@ import re import urllib.parse from fractions import Fraction -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any, Callable, Literal from uuid import uuid4 import dash @@ -519,7 +519,7 @@ def hook_up_fig_with_struct_viewer( fig: go.Figure, df: pd.DataFrame, struct_col: str = "structure", - validate_id: Callable[[str], bool] = lambda id: True, + transform_id: Callable[[str], str | Literal[False]] = lambda mat_id: mat_id, highlight_selected: Callable[[dict[str, Any]], dict[str, Any]] | None = None, ) -> Dash: """Create a Dash app that hooks up a Plotly figure with a Crystal Toolkit structure @@ -555,11 +555,10 @@ def hook_up_fig_with_struct_viewer( struct_col (str, optional): Name of the column in the data frame that contains the structures. Defaults to 'structure'. Can be instances of pymatgen.core.Structure or dicts created with Structure.as_dict(). - validate_id (Callable[[str], bool], optional): Function that takes a string + transform_id (Callable[[str], str | False], optional): Function that takes a string extracted from the hovertext key of a hoverData event payload and returns - True if the string is a valid df row index. Defaults to lambda - id: True. Useful for not running the update-structure - callback on unexpected data. + a string that can be used to index the dataframe. Return False to prevent + the update-structure callback from being called. highlight_selected (Callable[[dict[str, Any]], dict[str, Any]], optional): Function that takes the clicked or last-hovered point and returns a dict of kwargs to be passed to go.Figure.add_annotation() to highlight said point. @@ -642,9 +641,8 @@ def update_structure( # hover_data and click_data are identical since a hover event always precedes a # click so we always use hover_data - material_id = hover_data["points"][0]["hovertext"] - if not validate_id(material_id): - print(f"bad {material_id=}") + material_id = transform_id(hover_data["points"][0]["hovertext"]) + if material_id is False: raise dash.exceptions.PreventUpdate struct = df[struct_col][material_id]