Skip to content

Commit

Permalink
Fixed bug when the signature of the function to create contains non-l…
Browse files Browse the repository at this point in the history
…ocally available type hints. Fixes #32
  • Loading branch information
Sylvain MARIE committed Mar 26, 2019
1 parent 17fd618 commit 6615e31
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 17 deletions.
73 changes: 57 additions & 16 deletions makefun/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,12 @@ def _is_generator_func(func_impl):
return isgeneratorfunction(func_impl)


class DefaultHolder:
class _SymbolRef:
"""
A class used to protect signature default values and type hints when the local context would not be able
to evaluate them properly when the new function is created. In this case we store them under a known name,
we add that name to the locals(), and we use this symbol that has a repr() equal to the name.
"""
__slots__ = 'varname'

def __init__(self, varname):
Expand All @@ -267,22 +272,16 @@ def get_signature_string(func_name, func_signature, evaldict):
# protect the parameters if needed
new_params = []
for p_name, p in func_signature.parameters.items():
if p.default is not Parameter.empty and not isinstance(p.default, (int, str, float, bool)):
# check if the repr() of the default value is equal to itself.
needs_protection = True
try:
deflt = eval(repr(p.default))
needs_protection = deflt != p.default
except SyntaxError:
pass

# if we have any problem, we need to protect the default value
if needs_protection:
# store the object in the evaldict and insert name
varname = "DEFAULT_%s" % p_name
evaldict[varname] = p.default
p = Parameter(p.name, kind=p.kind, default=DefaultHolder(varname), annotation=p.annotation)
# if default value can not be evaluated, protect it
default_needs_protection = _signature_symbol_needs_protection(p.default, evaldict)
new_default = _protect_signature_symbol(p.default, default_needs_protection, "DEFAULT_%s" % p_name, evaldict)

# if type hint can not be evaluated, protect it
annotation_needs_protection = _signature_symbol_needs_protection(p.annotation, evaldict)
new_annotation = _protect_signature_symbol(p.annotation, annotation_needs_protection, "HINT_%s" % p_name, evaldict)

# replace the parameter with the possibly new default and hint
p = Parameter(p.name, kind=p.kind, default=new_default, annotation=new_annotation)
new_params.append(p)

# copy signature object
Expand All @@ -292,6 +291,48 @@ def get_signature_string(func_name, func_signature, evaldict):
return "%s%s:" % (func_name, s)


def _signature_symbol_needs_protection(symbol, evaldict):
"""
Helper method for signature symbols (defaults, type hints) protection.
Returns True if the given symbol needs to be protected - that is, if its repr() can not be correctly evaluated with current evaldict.
:param symbol:
:return:
"""
if symbol is not None and symbol is not Parameter.empty and not isinstance(symbol, (int, str, float, bool)):
# check if the repr() of the default value is equal to itself.
try:
deflt = eval(repr(symbol), evaldict)
needs_protection = deflt != symbol
except SyntaxError:
needs_protection = True
else:
needs_protection = False

return needs_protection


def _protect_signature_symbol(val, needs_protection, varname, evaldict):
"""
Helper method for signature symbols (defaults, type hints) protection.
Returns either `val`, or a protection symbol. In that case the protection symbol
is created with name `varname` and inserted into `evaldict`
:param val:
:param needs_protection:
:param varname:
:param evaldict:
:return:
"""
if needs_protection:
# store the object in the evaldict and insert name
evaldict[varname] = val
return _SymbolRef(varname)
else:
return val


def get_signature_from_string(func_sig_str, evaldict):
"""
Creates a `Signature` object from the given function signature string.
Expand Down
13 changes: 13 additions & 0 deletions makefun/tests/_test_py35.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,16 @@ async def my_native_coroutine_handler(sleep_time):
return sleep_time

return my_native_coroutine_handler


def make_ref_function():
"""Returns a function with a type hint that is locally defined """

# the symbol is defined here, so it is not seen outside
class A:
pass

def ref(a: A):
pass

return ref
15 changes: 14 additions & 1 deletion makefun/tests/test_advanced.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import logging
import sys
from copy import copy, deepcopy

import pytest

Expand Down Expand Up @@ -128,3 +127,17 @@ def g(self):
# our mod
assert C.D.g.__qualname__ == 'test_qualname_when_nested.<locals>.C.D.g'
assert str(signature(C.D.g)) == "(self, a)"


@pytest.mark.skipif(sys.version_info < (3, 5), reason="requires python 3.5 or higher (non-comment type hints)")
def test_type_hint_error():
""" Test for https://github.com/smarie/python-makefun/issues/32 """

from makefun.tests._test_py35 import make_ref_function
ref_f = make_ref_function()

@wraps(ref_f)
def foo(a):
return a

assert foo(10) == 10

0 comments on commit 6615e31

Please sign in to comment.