Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make @transition function-based decorator a class-based decorator #34

Merged
merged 14 commits into from
Oct 10, 2021
2 changes: 1 addition & 1 deletion finite_state_machine/draw_state_diagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def generate_state_diagram_markdown(cls, initial_state):

class_fns = inspect.getmembers(cls, predicate=inspect.isfunction)
state_transitions: List[Transition] = [
func.__fsm for name, func in class_fns if hasattr(func, "__fsm")
func._fsm for name, func in class_fns if hasattr(func, "_fsm")
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the corresponding change in the draw_state_diagram CLI

]

transition_template = " {source} --> {target} : {name}\n"
Expand Down
239 changes: 125 additions & 114 deletions finite_state_machine/state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,119 +23,130 @@ class Transition(NamedTuple):
on_error: Union[bool, int, str]


def transition(source, target, conditions=None, on_error=None):
allowed_types = (str, bool, int, Enum)

if isinstance(source, allowed_types):
source = [source]
if not isinstance(source, list):
raise ValueError("Source can be a bool, int, string, Enum, or list")
for item in source:
if not isinstance(item, allowed_types):
raise ValueError("Source can be a bool, int, string, Enum, or list")
class transition:
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lowercase class name allows us to keep the interface the same

# TODO remove from part 1 PR
_fsm_transition_mapping = {}

if not isinstance(target, allowed_types):
raise ValueError("Target needs to be a bool, int or string")

if not conditions:
conditions = []
if not isinstance(conditions, list):
raise ValueError("conditions must be a list")
for condition in conditions:
if not isinstance(condition, types.FunctionType):
raise ValueError("conditions list must contain functions")

if on_error:
if not isinstance(on_error, allowed_types):
raise ValueError("on_error needs to be a bool, int or string")

def transition_decorator(func):
func.__fsm = Transition(func.__name__, source, target, conditions, on_error)

synchronous_execution = not asyncio.iscoroutinefunction(func)
if synchronous_execution:

@functools.wraps(func)
def _wrapper(*args, **kwargs):
try:
self, rest = args
except ValueError:
self = args[0]

if self.state not in source:
exception_message = (
f"Current state is {self.state}. "
f"{func.__name__} allows transitions from {source}."
)
raise InvalidStartState(exception_message)

conditions_not_met = []
for condition in conditions:
if condition(*args, **kwargs) is not True:
conditions_not_met.append(condition)
if conditions_not_met:
raise ConditionsNotMet(conditions_not_met)

if not on_error:
result = func(*args, **kwargs)
self.state = target
return result

try:
result = func(*args, **kwargs)
self.state = target
return result
except Exception:
# TODO should we log this somewhere?
# logger.error? maybe have an optional parameter to set this up
# how to libraries log?
self.state = on_error
return

return _wrapper
else:
def __init__(self, source, target, conditions=None, on_error=None):
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This way of doing things allows us to process decorator parameters inside of the __init__ function. This feels a bit cleaner than before.

allowed_types = (str, bool, int, Enum)

@functools.wraps(func)
async def _wrapper(*args, **kwargs):
try:
self, rest = args
except ValueError:
self = args[0]

if self.state not in source:
exception_message = (
f"Current state is {self.state}. "
f"{func.__name__} allows transitions from {source}."
)
raise InvalidStartState(exception_message)

conditions_not_met = []
for condition in conditions:
if asyncio.iscoroutinefunction(condition):
condition_result = await condition(*args, **kwargs)
else:
condition_result = condition(*args, **kwargs)
if condition_result is not True:
conditions_not_met.append(condition)
if conditions_not_met:
raise ConditionsNotMet(conditions_not_met)

if not on_error:
result = await func(*args, **kwargs)
self.state = target
return result

try:
result = await func(*args, **kwargs)
self.state = target
return result
except Exception:
# TODO should we log this somewhere?
# logger.error? maybe have an optional parameter to set this up
# how to libraries log?
self.state = on_error
return

return _wrapper

return transition_decorator
if isinstance(source, allowed_types):
source = [source]
if not isinstance(source, list):
raise ValueError("Source can be a bool, int, string, Enum, or list")
for item in source:
if not isinstance(item, allowed_types):
raise ValueError("Source can be a bool, int, string, Enum, or list")
self.source = source

if not isinstance(target, allowed_types):
raise ValueError("Target needs to be a bool, int or string")
self.target = target

if not conditions:
conditions = []
if not isinstance(conditions, list):
raise ValueError("conditions must be a list")
for condition in conditions:
if not isinstance(condition, types.FunctionType):
raise ValueError("conditions list must contain functions")
self.conditions = conditions

if on_error:
if not isinstance(on_error, allowed_types):
raise ValueError("on_error needs to be a bool, int or string")
self.on_error = on_error

def __call__(self, func):
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Decorators function take functions as input. Our class-based implementation will make our object instance callable and accept a function.

func._fsm = Transition(
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed the internal location of Transition metadata from .__fsm to ._fsm since the double underscore version was mangling names when introspected in the draw_state_diagram CLI

func.__name__,
self.source,
self.target,
self.conditions,
self.on_error,
)
# TODO update on class transition mapping
self.__class__._fsm_transition_mapping[func.__qualname__] = func

@functools.wraps(func)
def sync_callable(*args, **kwargs):
try:
state_machine, rest = args
except ValueError:
state_machine = args[0]

if state_machine.state not in self.source:
exception_message = (
f"Current state is {state_machine.state}. "
f"{func.__name__} allows transitions from {self.source}."
)
raise InvalidStartState(exception_message)

conditions_not_met = []
for condition in self.conditions:
if condition(*args, **kwargs) is not True:
conditions_not_met.append(condition)
if conditions_not_met:
raise ConditionsNotMet(conditions_not_met)

if not self.on_error:
result = func(*args, **kwargs)
state_machine.state = self.target
return result

try:
result = func(*args, **kwargs)
state_machine.state = self.target
return result
except Exception:
# TODO should we log this somewhere?
# logger.error? maybe have an optional parameter to set this up
# how to libraries log?
state_machine.state = self.on_error
return

@functools.wraps(func)
async def async_callable(*args, **kwargs):
try:
state_machine, rest = args
except ValueError:
state_machine = args[0]

if state_machine.state not in self.source:
exception_message = (
f"Current state is {state_machine.state}. "
f"{func.__name__} allows transitions from {self.source}."
)
raise InvalidStartState(exception_message)

conditions_not_met = []
for condition in self.conditions:
if asyncio.iscoroutinefunction(condition):
condition_result = await condition(*args, **kwargs)
else:
condition_result = condition(*args, **kwargs)
if condition_result is not True:
conditions_not_met.append(condition)
if conditions_not_met:
raise ConditionsNotMet(conditions_not_met)

if not self.on_error:
result = await func(*args, **kwargs)
state_machine.state = self.target
return result

try:
result = await func(*args, **kwargs)
state_machine.state = self.target
return result
except Exception:
# TODO should we log this somewhere?
# logger.error? maybe have an optional parameter to set this up
# how to libraries log?
state_machine.state = self.on_error
return

if asyncio.iscoroutinefunction(func):
return async_callable
else:
return sync_callable