Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make @transition function-based decorator a class-based decorator #34

Merged
merged 14 commits into from
Oct 10, 2021
6 changes: 3 additions & 3 deletions finite_state_machine/draw_state_diagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import sys
from typing import List

from finite_state_machine.state_machine import Transition
from finite_state_machine.state_machine import TransitionDetails


def import_state_machine_class(path): # pragma: no cover
Expand All @@ -28,8 +28,8 @@ def generate_state_diagram_markdown(cls, initial_state):
"""

class_fns = inspect.getmembers(cls, predicate=inspect.isfunction)
state_transitions: List[Transition] = [
func.__fsm for name, func in class_fns if hasattr(func, "__fsm")
state_transitions: List[TransitionDetails] = [
func._fsm for name, func in class_fns if hasattr(func, "_fsm")
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the corresponding change in the draw_state_diagram CLI

]

transition_template = " {source} --> {target} : {name}\n"
Expand Down
234 changes: 120 additions & 114 deletions finite_state_machine/state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,127 +15,133 @@ def __init__(self):
raise ValueError("Need to set a state instance variable")


class Transition(NamedTuple):
class TransitionDetails(NamedTuple):
name: str
source: Union[list, bool, int, str]
target: Union[bool, int, str]
conditions: list
on_error: Union[bool, int, str]


def transition(source, target, conditions=None, on_error=None):
allowed_types = (str, bool, int, Enum)
class transition:
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lowercase class name allows us to keep the interface the same

def __init__(self, source, target, conditions=None, on_error=None):
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This way of doing things allows us to process decorator parameters inside of the __init__ function. This feels a bit cleaner than before.

allowed_types = (str, bool, int, Enum)

if isinstance(source, allowed_types):
source = [source]
if not isinstance(source, list):
raise ValueError("Source can be a bool, int, string, Enum, or list")
for item in source:
if not isinstance(item, allowed_types):
if isinstance(source, allowed_types):
source = [source]
if not isinstance(source, list):
raise ValueError("Source can be a bool, int, string, Enum, or list")

if not isinstance(target, allowed_types):
raise ValueError("Target needs to be a bool, int or string")

if not conditions:
conditions = []
if not isinstance(conditions, list):
raise ValueError("conditions must be a list")
for condition in conditions:
if not isinstance(condition, types.FunctionType):
raise ValueError("conditions list must contain functions")

if on_error:
if not isinstance(on_error, allowed_types):
raise ValueError("on_error needs to be a bool, int or string")

def transition_decorator(func):
func.__fsm = Transition(func.__name__, source, target, conditions, on_error)

synchronous_execution = not asyncio.iscoroutinefunction(func)
if synchronous_execution:

@functools.wraps(func)
def _wrapper(*args, **kwargs):
try:
self, rest = args
except ValueError:
self = args[0]

if self.state not in source:
exception_message = (
f"Current state is {self.state}. "
f"{func.__name__} allows transitions from {source}."
)
raise InvalidStartState(exception_message)

conditions_not_met = []
for condition in conditions:
if condition(*args, **kwargs) is not True:
conditions_not_met.append(condition)
if conditions_not_met:
raise ConditionsNotMet(conditions_not_met)

if not on_error:
result = func(*args, **kwargs)
self.state = target
return result

try:
result = func(*args, **kwargs)
self.state = target
return result
except Exception:
# TODO should we log this somewhere?
# logger.error? maybe have an optional parameter to set this up
# how to libraries log?
self.state = on_error
return

return _wrapper
for item in source:
if not isinstance(item, allowed_types):
raise ValueError("Source can be a bool, int, string, Enum, or list")
self.source = source

if not isinstance(target, allowed_types):
raise ValueError("Target needs to be a bool, int or string")
self.target = target

if not conditions:
conditions = []
if not isinstance(conditions, list):
raise ValueError("conditions must be a list")
for condition in conditions:
if not isinstance(condition, types.FunctionType):
raise ValueError("conditions list must contain functions")
self.conditions = conditions

if on_error:
if not isinstance(on_error, allowed_types):
raise ValueError("on_error needs to be a bool, int or string")
self.on_error = on_error

def __call__(self, func):
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Decorators function take functions as input. Our class-based implementation will make our object instance callable and accept a function.

func._fsm = TransitionDetails(
func.__name__,
self.source,
self.target,
self.conditions,
self.on_error,
)

@functools.wraps(func)
def sync_callable(*args, **kwargs):
try:
state_machine, rest = args
except ValueError:
state_machine = args[0]

if state_machine.state not in self.source:
exception_message = (
f"Current state is {state_machine.state}. "
f"{func.__name__} allows transitions from {self.source}."
)
raise InvalidStartState(exception_message)

conditions_not_met = []
for condition in self.conditions:
if condition(*args, **kwargs) is not True:
conditions_not_met.append(condition)
if conditions_not_met:
raise ConditionsNotMet(conditions_not_met)

if not self.on_error:
result = func(*args, **kwargs)
state_machine.state = self.target
return result

try:
result = func(*args, **kwargs)
state_machine.state = self.target
return result
except Exception:
# TODO should we log this somewhere?
# logger.error? maybe have an optional parameter to set this up
# how to libraries log?
state_machine.state = self.on_error
return

@functools.wraps(func)
async def async_callable(*args, **kwargs):
try:
state_machine, rest = args
except ValueError:
state_machine = args[0]

if state_machine.state not in self.source:
exception_message = (
f"Current state is {state_machine.state}. "
f"{func.__name__} allows transitions from {self.source}."
)
raise InvalidStartState(exception_message)

conditions_not_met = []
for condition in self.conditions:
if asyncio.iscoroutinefunction(condition):
condition_result = await condition(*args, **kwargs)
else:
condition_result = condition(*args, **kwargs)
if condition_result is not True:
conditions_not_met.append(condition)
if conditions_not_met:
raise ConditionsNotMet(conditions_not_met)

if not self.on_error:
result = await func(*args, **kwargs)
state_machine.state = self.target
return result

try:
result = await func(*args, **kwargs)
state_machine.state = self.target
return result
except Exception:
# TODO should we log this somewhere?
# logger.error? maybe have an optional parameter to set this up
# how to libraries log?
state_machine.state = self.on_error
return

if asyncio.iscoroutinefunction(func):
return async_callable
else:

@functools.wraps(func)
async def _wrapper(*args, **kwargs):
try:
self, rest = args
except ValueError:
self = args[0]

if self.state not in source:
exception_message = (
f"Current state is {self.state}. "
f"{func.__name__} allows transitions from {source}."
)
raise InvalidStartState(exception_message)

conditions_not_met = []
for condition in conditions:
if asyncio.iscoroutinefunction(condition):
condition_result = await condition(*args, **kwargs)
else:
condition_result = condition(*args, **kwargs)
if condition_result is not True:
conditions_not_met.append(condition)
if conditions_not_met:
raise ConditionsNotMet(conditions_not_met)

if not on_error:
result = await func(*args, **kwargs)
self.state = target
return result

try:
result = await func(*args, **kwargs)
self.state = target
return result
except Exception:
# TODO should we log this somewhere?
# logger.error? maybe have an optional parameter to set this up
# how to libraries log?
self.state = on_error
return

return _wrapper

return transition_decorator
return sync_callable