diff --git a/src/safeds_runner/__init__.py b/src/safeds_runner/__init__.py index 7a379d0..40ab35c 100644 --- a/src/safeds_runner/__init__.py +++ b/src/safeds_runner/__init__.py @@ -1,9 +1,10 @@ """A runner for the Python code generated from Safe-DS programs.""" -from .server._pipeline_manager import file_mtime, memoized_call, save_placeholder +from .server._pipeline_manager import file_mtime, memoized_dynamic_call, memoized_static_call, save_placeholder __all__ = [ "file_mtime", - "memoized_call", + "memoized_static_call", + "memoized_dynamic_call", "save_placeholder", ] diff --git a/src/safeds_runner/server/_pipeline_manager.py b/src/safeds_runner/server/_pipeline_manager.py index 02584c2..8c1ca53 100644 --- a/src/safeds_runner/server/_pipeline_manager.py +++ b/src/safeds_runner/server/_pipeline_manager.py @@ -309,7 +309,7 @@ def save_placeholder(placeholder_name: str, value: Any) -> None: current_pipeline.save_placeholder(placeholder_name, value) -def memoized_call( +def memoized_static_call( function_name: str, function_callable: typing.Callable, parameters: list[Any], @@ -342,6 +342,52 @@ def memoized_call( return memoization_map.memoized_function_call(function_name, function_callable, parameters, hidden_parameters) +def memoized_dynamic_call( + function_name: str, + function_callable: typing.Callable | None, + parameters: list[Any], + hidden_parameters: list[Any], +) -> Any: + """ + Dynamically call a function that can be memoized and save the result. + + If a function has been previously memoized, the previous result may be reused. + Dynamically calling in this context means, that if a callable is provided (e.g. if default parameters are set), it will be called. + If no such callable is provided, the function name will be used to look up the function on the instance passed as the first parameter in the parameter list. + + Parameters + ---------- + function_name : str + Simple function name + function_callable : typing.Callable | None + Function that is called and memoized if the result was not found in the memoization map or none, if the function handle should be in the provided instance + parameters : list[Any] + List of parameters for the function, the first parameter should be the instance the function should be called on (receiver) + hidden_parameters : list[Any] + List of hidden parameters for the function. This is used for memoizing some impure functions. + + Returns + ------- + Any + The result of the specified function, if any exists + """ + if current_pipeline is None: + return None # pragma: no cover + fully_qualified_function_name = ( + parameters[0].__class__.__module__ + "." + parameters[0].__class__.__qualname__ + "." + function_name + ) + memoization_map = current_pipeline.get_memoization_map() + if function_callable is None: + function_target_bound = getattr(parameters[0], function_name) + function_callable = function_target_bound.__func__ + return memoization_map.memoized_function_call( + fully_qualified_function_name, + function_callable, + parameters, + hidden_parameters, + ) + + def file_mtime(filename: str) -> int | None: """ Get the last modification timestamp of the provided file. diff --git a/tests/safeds_runner/server/test_memoization.py b/tests/safeds_runner/server/test_memoization.py index 33960ae..e10ca70 100644 --- a/tests/safeds_runner/server/test_memoization.py +++ b/tests/safeds_runner/server/test_memoization.py @@ -15,7 +15,12 @@ _make_hashable, ) from safeds_runner.server._messages import MessageDataProgram, ProgramMainInformation -from safeds_runner.server._pipeline_manager import PipelineProcess, file_mtime, memoized_call +from safeds_runner.server._pipeline_manager import ( + PipelineProcess, + file_mtime, + memoized_dynamic_call, + memoized_static_call, +) class UnhashableClass: @@ -31,7 +36,7 @@ def __hash__(self) -> int: ], ids=["function_pure", "function_impure_readfile"], ) -def test_memoization_already_present_values( +def test_memoization_static_already_present_values( function_name: str, params: list, hidden_params: list, @@ -57,7 +62,7 @@ def test_memoization_already_present_values( [], [sys.getsizeof(expected_result)], ) - result = _pipeline_manager.memoized_call(function_name, lambda *_: None, params, hidden_params) + result = _pipeline_manager.memoized_static_call(function_name, lambda *_: None, params, hidden_params) assert result == expected_result @@ -71,7 +76,7 @@ def test_memoization_already_present_values( ], ids=["function_pure", "function_impure_readfile", "function_dict", "function_lambda"], ) -def test_memoization_not_present_values( +def test_memoization_static_not_present_values( function_name: str, function: typing.Callable, params: list, @@ -86,13 +91,146 @@ def test_memoization_not_present_values( MemoizationMap({}, {}), ) # Save value in map - result = memoized_call(function_name, function, params, hidden_params) + result = memoized_static_call(function_name, function, params, hidden_params) + assert result == expected_result + # Test if value is actually saved by calling another function that does not return the expected result + result2 = memoized_static_call(function_name, lambda *_: None, params, hidden_params) + assert result2 == expected_result + + +class BaseClass: + def __init__(self) -> None: + pass + + def method1(self) -> int: + return 1 + + def method2(self, default: int = 5) -> int: + return 1 * default + + +class ChildClass(BaseClass): + def __init__(self) -> None: + super().__init__() + + def method1(self) -> int: + return 2 + + def method2(self, default: int = 3) -> int: + return 2 * default + + +@pytest.mark.parametrize( + argnames="function_name,function,params,hidden_params,expected_result", + argvalues=[ + ("method1", None, [BaseClass()], [], 1), + ("method1", None, [ChildClass()], [], 2), + ("method2", lambda instance, *_: instance.method2(default=7), [BaseClass(), 7], [], 7), + ("method2", lambda instance, *_: instance.method2(default=7), [ChildClass(), 7], [], 14), + ], + ids=["member_call_base", "member_call_child", "member_call_base_lambda", "member_call_child_lambda"], +) +def test_memoization_dynamic( + function_name: str, + function: typing.Callable | None, + params: list, + hidden_params: list, + expected_result: Any, +) -> None: + _pipeline_manager.current_pipeline = PipelineProcess( + MessageDataProgram({}, ProgramMainInformation("", "", "")), + "", + Queue(), + {}, + MemoizationMap({}, {}), + ) + # Save value in map + result = memoized_dynamic_call(function_name, function, params, hidden_params) assert result == expected_result # Test if value is actually saved by calling another function that does not return the expected result - result2 = memoized_call(function_name, lambda *_: None, params, hidden_params) + result2 = memoized_dynamic_call(function_name, lambda *_: None, params, hidden_params) assert result2 == expected_result +@pytest.mark.parametrize( + argnames="function_name,function,params,hidden_params,fully_qualified_function_name", + argvalues=[ + ("method1", None, [BaseClass()], [], "tests.safeds_runner.server.test_memoization.BaseClass.method1"), + ("method1", None, [ChildClass()], [], "tests.safeds_runner.server.test_memoization.ChildClass.method1"), + ( + "method2", + lambda instance, *_: instance.method2(default=7), + [BaseClass(), 7], + [], + "tests.safeds_runner.server.test_memoization.BaseClass.method2", + ), + ( + "method2", + lambda instance, *_: instance.method2(default=7), + [ChildClass(), 7], + [], + "tests.safeds_runner.server.test_memoization.ChildClass.method2", + ), + ], + ids=["member_call_base", "member_call_child", "member_call_base_lambda", "member_call_child_lambda"], +) +def test_memoization_dynamic_contains_correct_fully_qualified_name( + function_name: str, + function: typing.Callable | None, + params: list, + hidden_params: list, + fully_qualified_function_name: Any, +) -> None: + _pipeline_manager.current_pipeline = PipelineProcess( + MessageDataProgram({}, ProgramMainInformation("", "", "")), + "", + Queue(), + {}, + MemoizationMap({}, {}), + ) + # Save value in map + result = memoized_dynamic_call(function_name, function, params, hidden_params) + # Test if value is actually saved with the correct function name + result2 = memoized_static_call(fully_qualified_function_name, lambda *_: None, params, hidden_params) + assert result == result2 + + +@pytest.mark.parametrize( + argnames="function_name,function,params,hidden_params,fully_qualified_function_name", + argvalues=[ + ("method1", None, [ChildClass()], [], "tests.safeds_runner.server.test_memoization.BaseClass.method1"), + ( + "method2", + lambda instance, *_: instance.method2(default=7), + [ChildClass(), 7], + [], + "tests.safeds_runner.server.test_memoization.BaseClass.method2", + ), + ], + ids=["member_call_child", "member_call_child_lambda"], +) +def test_memoization_dynamic_not_base_name( + function_name: str, + function: typing.Callable | None, + params: list, + hidden_params: list, + fully_qualified_function_name: Any, +) -> None: + _pipeline_manager.current_pipeline = PipelineProcess( + MessageDataProgram({}, ProgramMainInformation("", "", "")), + "", + Queue(), + {}, + MemoizationMap({}, {}), + ) + # Save value in map + result = memoized_dynamic_call(function_name, function, params, hidden_params) + # Test if value is actually saved with the correct function name + result2 = memoized_static_call(fully_qualified_function_name, lambda *_: None, params, hidden_params) + assert result is not None + assert result2 is None + + @pytest.mark.parametrize( argnames="function_name,function,params,hidden_params,expected_result", argvalues=[ @@ -101,7 +239,7 @@ def test_memoization_not_present_values( ], ids=["unhashable_params", "unhashable_hidden_params"], ) -def test_memoization_unhashable_values( +def test_memoization_static_unhashable_values( function_name: str, function: typing.Callable, params: list, @@ -115,8 +253,7 @@ def test_memoization_unhashable_values( {}, MemoizationMap({}, {}), ) - - result = memoized_call(function_name, function, params, hidden_params) + result = memoized_static_call(function_name, function, params, hidden_params) assert result == expected_result