diff --git a/src/makefun/main.py b/src/makefun/main.py index 6228ec7..5241dc7 100644 --- a/src/makefun/main.py +++ b/src/makefun/main.py @@ -278,11 +278,6 @@ def create_function(func_signature, # type: Union[str, Signature] else: raise TypeError("Invalid type for `func_signature`: %s" % type(func_signature)) - if isinstance(attrs.get('__signature__'), str): - # __signature__ must be a Signature object, so if it is a string, - # we need to evaluate it. - attrs['__signature__'] = get_signature_from_string(attrs['__signature__'], evaldict)[1] - # extract all information needed from the `Signature` params_to_kw_assignment_mode = get_signature_params(func_signature) params_names = list(params_to_kw_assignment_mode.keys()) @@ -970,7 +965,17 @@ def _get_args_for_wrapping(wrapped, new_sig, remove_args, prepend_args, append_a # PEP362: always set `__wrapped__`, and if signature was changed, set `__signature__` too all_attrs["__wrapped__"] = wrapped if has_new_sig: - all_attrs["__signature__"] = func_sig + if isinstance(func_sig, Signature): + all_attrs["__signature__"] = func_sig + else: + # __signature__ must be a Signature object, so if it is a string we need to evaluate it. + frame = _get_callerframe(offset=1) + evaldict, _ = extract_module_and_evaldict(frame) + # Here we could wish to directly override `func_name` and `func_sig` so that this does not have to be done + # again by `create_function` later... Would this be risky ? + _func_name, func_sig_as_sig, _ = get_signature_from_string(func_sig, evaldict) + all_attrs["__signature__"] = func_sig_as_sig + all_attrs.update(attrs) return func_name, func_sig, doc, qualname, co_name, module_name, all_attrs diff --git a/tests/test_advanced.py b/tests/test_advanced.py index c2c8141..df39914 100644 --- a/tests/test_advanced.py +++ b/tests/test_advanced.py @@ -245,3 +245,18 @@ def foo(a): return a assert foo(10) == 10 + + +@pytest.mark.skipif(sys.version_info < (3, 5), reason="requires python 3.5 or higher (non-comment type hints)") +def test_type_hint_error_sigchange(): + """ Test for https://github.com/smarie/python-makefun/issues/32 """ + + from tests._test_py35 import make_ref_function + from typing import Any + ref_f = make_ref_function() + + @wraps(ref_f, new_sig="(a: Any)") + def foo(a): + return a + + assert foo(10) == 10