diff --git a/.gitignore b/.gitignore index 0836d14982..3db75bdf3f 100644 --- a/.gitignore +++ b/.gitignore @@ -29,5 +29,4 @@ src/*.egg-info .venv cov.xml .coverage.* -*.psycache __pycache__ diff --git a/doc/Makefile b/doc/Makefile new file mode 100644 index 0000000000..b4cdc968b6 --- /dev/null +++ b/doc/Makefile @@ -0,0 +1,18 @@ + +all: developer_guide reference_guide user_guide +.PHONY: developer_guide reference_guide user_guide + +developer_guide: + make -C developer_guide html SPHINXOPTS="-W --keep-going" + +reference_guide: + make -C reference_guide html SPHINXOPTS="-W --keep-going" + +user_guide: + make -C user_guide html SPHINXOPTS="-W --keep-going" + +clean: + make -C developer_guide clean + make -C reference_guide allclean + make -C user_guide clean + diff --git a/src/psyclone/psyir/nodes/__init__.py b/src/psyclone/psyir/nodes/__init__.py index ddfd927e91..762c451e6e 100644 --- a/src/psyclone/psyir/nodes/__init__.py +++ b/src/psyclone/psyir/nodes/__init__.py @@ -195,5 +195,5 @@ 'OMPScheduleClause', 'OMPFirstprivateClause', 'OMPSharedClause', - 'OMPDependClause' - ] + 'OMPDependClause', +] diff --git a/src/psyclone/psyir/nodes/call.py b/src/psyclone/psyir/nodes/call.py index f49ce59577..75e65d057b 100644 --- a/src/psyclone/psyir/nodes/call.py +++ b/src/psyclone/psyir/nodes/call.py @@ -38,32 +38,13 @@ from collections.abc import Iterable -from psyclone.configuration import Config from psyclone.core import AccessType from psyclone.errors import GenerationError -from psyclone.psyir.nodes.container import Container from psyclone.psyir.nodes.statement import Statement from psyclone.psyir.nodes.datanode import DataNode from psyclone.psyir.nodes.reference import Reference -from psyclone.psyir.nodes.routine import Routine -from psyclone.psyir.symbols import ( - RoutineSymbol, - Symbol, - SymbolError, - UnsupportedFortranType, - DataSymbol, -) +from psyclone.psyir.symbols import RoutineSymbol from typing import List -from psyclone.errors import PSycloneError - - -class CallMatchingArgumentsNotFound(PSycloneError): - '''Exception to signal that matching arguments have not been found - for this routine - ''' - def __init__(self, value): - PSycloneError.__init__(self, value) - self.value = "CallMatchingArgumentsNotFound: " + str(value) class Call(Statement, DataNode): @@ -456,11 +437,18 @@ def copy(self): return new_copy - def get_callees(self): + def get_callees(self, ignore_missing_modules: bool = False): ''' Searches for the implementation(s) of all potential target routines for this Call without any arguments check. + Deprecation warning: This only exists for backwards compatibility + reason. It's recommende to directly use `CallRoutineMatcher`. + + :param ignore_missing_modules: If a module wasn't found, return 'None' + instead of throwing an exception 'ModuleNotFound'. + :type ignore_missing_modules: bool + :returns: the Routine(s) that this call targets. :rtype: list[:py:class:`psyclone.psyir.nodes.Routine`] @@ -468,277 +456,23 @@ def get_callees(self): in any containers in scope at the call site. ''' - def _location_txt(node): - ''' - Utility to generate meaningful location text. - - :param node: a PSyIR node. - :type node: :py:class:`psyclone.psyir.nodes.Node` - - :returns: description of location of node. - :rtype: str - ''' - if isinstance(node, Container): - return f"Container '{node.name}'" - out_lines = node.debug_string().split("\n") - idx = -1 - while not out_lines[idx]: - idx -= 1 - last_line = out_lines[idx] - return f"code:\n'{out_lines[0]}\n...\n{last_line}'" - - rsym = self.routine.symbol - if rsym.is_unresolved: - - # Check for any "raw" Routines, i.e. ones that are not - # in a Container. Such Routines would exist in the PSyIR - # as a child of a FileContainer (if the PSyIR contains a - # FileContainer). Note, if the PSyIR does contain a - # FileContainer, it will be the root node of the PSyIR. - for routine in self.root.children: - if (isinstance(routine, Routine) and - routine.name.lower() == rsym.name.lower()): - return [routine] - - # Now check for any wildcard imports and see if they can - # be used to resolve the symbol. - wildcard_names = [] - containers_not_found = [] - current_table = self.scope.symbol_table - while current_table: - for container_symbol in current_table.containersymbols: - if container_symbol.wildcard_import: - wildcard_names.append(container_symbol.name) - try: - container = container_symbol.find_container_psyir( - local_node=self) - except SymbolError: - container = None - if not container: - # Failed to find/process this Container. - containers_not_found.append(container_symbol.name) - continue - routines = [] - for name in container.resolve_routine(rsym.name): - psyir = container.find_routine_psyir(name) - if psyir: - routines.append(psyir) - if routines: - return routines - current_table = current_table.parent_symbol_table() - if not wildcard_names: - wc_text = "there are no wildcard imports" - else: - if containers_not_found: - wc_text = ( - f"attempted to resolve the wildcard imports from" - f" {wildcard_names}. However, failed to find the " - f"source for {containers_not_found}. The module search" - f" path is set to {Config.get().include_paths}") - else: - wc_text = (f"wildcard imports from {wildcard_names}") - raise NotImplementedError( - f"Failed to find the source code of the unresolved routine " - f"'{rsym.name}' - looked at any routines in the same source " - f"file and {wc_text}. Searching for external routines " - f"that are only resolved at link time is not supported.") - - root_node = self.ancestor(Container) - if not root_node: - root_node = self.root - container = root_node - can_be_private = True - - if rsym.is_import: - cursor = rsym - # A Routine imported from another Container must be public in that - # Container. - can_be_private = False - while cursor.is_import: - csym = cursor.interface.container_symbol - try: - container = csym.find_container_psyir(local_node=self) - except SymbolError: - raise NotImplementedError( - f"RoutineSymbol '{rsym.name}' is imported from " - f"Container '{csym.name}' but the source defining " - f"that container could not be found. The module search" - f" path is set to {Config.get().include_paths}") - imported_sym = container.symbol_table.lookup(cursor.name) - if imported_sym.visibility != Symbol.Visibility.PUBLIC: - # The required Symbol must be shadowed with a PRIVATE - # Symbol in this Container. This means that the one we - # actually want is brought into scope via a wildcard - # import. - # TODO #924 - Use ModuleManager to search? - raise NotImplementedError( - f"RoutineSymbol '{rsym.name}' is imported from " - f"Container '{csym.name}' but that Container defines " - f"a private Symbol of the same name. Searching for the" - f" Container that defines a public Routine with that " - f"name is not yet supported - TODO #924") - if not isinstance(imported_sym, RoutineSymbol): - # We now know that this is a RoutineSymbol so specialise it - # in place. - imported_sym.specialise(RoutineSymbol) - cursor = imported_sym - rsym = cursor - root_node = container - - if isinstance(rsym.datatype, UnsupportedFortranType): - # TODO #924 - an UnsupportedFortranType here typically indicates - # that the target is actually an interface. - raise NotImplementedError( - f"RoutineSymbol '{rsym.name}' exists in " - f"{_location_txt(root_node)} but is of " - f"UnsupportedFortranType:\n{rsym.datatype.declaration}\n" - f"Cannot get the PSyIR of such a routine.") - - if isinstance(container, Container): - routines = [] - for name in container.resolve_routine(rsym.name): - psyir = container.find_routine_psyir( - name, allow_private=can_be_private) - if psyir: - routines.append(psyir) - if routines: - return routines - - raise SymbolError( - f"Failed to find a Routine named '{rsym.name}' in " - f"{_location_txt(root_node)}. This is normally because the routine" - f" is within a CodeBlock.") - - def _check_argument_type_matches( - self, - call_arg: DataSymbol, - routine_arg: DataSymbol, - ) -> None: - """Return information whether argument types are matching. - This also supports 'optional' arguments by using - partial types. - - :param call_arg: One argument of the call - :param routine_arg: One argument of the routine - - :raises CallMatchingArgumentsNotFound: Raised if no matching argument - was found. - - """ - if isinstance( - routine_arg.datatype, UnsupportedFortranType - ): - # This could be an 'optional' argument. - # This has at least a partial data type - if ( - call_arg.datatype - != routine_arg.datatype.partial_datatype - ): - raise CallMatchingArgumentsNotFound( - f"Argument partial type mismatch of call " - f"argument '{call_arg}' and routine argument " - f"'{routine_arg}'" - ) - else: - if call_arg.datatype != routine_arg.datatype: - raise CallMatchingArgumentsNotFound( - f"Argument type mismatch of call argument " - f"'{call_arg}' and routine argument " - f"'{routine_arg}'" - ) - - def _get_argument_routine_match(self, routine: Routine): - '''Return a list of integers giving for each argument of the call - the index of the corresponding entry in the argument list of the - supplied routine. - - :return: None if no match was found, otherwise list of integers - referring to matching arguments. - :rtype: None|List[int] - ''' - - # Create a copy of the list of actual arguments to the routine. - # Once an argument has been successfully matched, set it to 'None' - routine_argument_list: List[DataSymbol] = ( - routine.symbol_table.argument_list[:] + from psyclone.psyir.tools import ( + CallRoutineMatcher ) - if len(self.arguments) > len(routine.symbol_table.argument_list): - raise CallMatchingArgumentsNotFound( - f"More arguments in call ('{self.debug_string()}')" - f" than callee (routine '{routine.name}')" - ) - - # Iterate over all arguments to the call - ret_arg_idx_list = [] - for call_arg_idx, call_arg in enumerate(self.arguments): - call_arg_idx: int - call_arg: DataSymbol - - # If the associated name is None, it's a positional argument - # => Just return the index if the types match - if self.argument_names[call_arg_idx] is None: - routine_arg = routine_argument_list[call_arg_idx] - routine_arg: DataSymbol - - self._check_argument_type_matches(call_arg, routine_arg) - - ret_arg_idx_list.append(call_arg_idx) - routine_argument_list[call_arg_idx] = None - continue - - # - # Next, we handle all named arguments - # - arg_name = self.argument_names[call_arg_idx] - routine_arg_idx = None - - for routine_arg_idx, routine_arg in enumerate( - routine_argument_list - ): - routine_arg: DataSymbol - - # Check if argument was already processed - if routine_arg is None: - continue - - if arg_name == routine_arg.name: - self._check_argument_type_matches(call_arg, routine_arg) - ret_arg_idx_list.append(routine_arg_idx) - break + call_routine_matcher: CallRoutineMatcher = CallRoutineMatcher(self) + call_routine_matcher.set_option( + ignore_missing_modules=ignore_missing_modules, + ) - else: - # It doesn't match => Raise exception - raise CallMatchingArgumentsNotFound( - f"Named argument '{arg_name}' not found for routine" - f" '{routine.name}' in call '{self.debug_string()}'" - ) - - routine_argument_list[routine_arg_idx] = None - - # - # Finally, we check if all left-over arguments are optional arguments - # - for routine_arg in routine_argument_list: - routine_arg: DataSymbol - - if routine_arg is None: - continue - - # TODO #759: Optional keyword is not yet supported in psyir. - # Hence, we use a simple string match. - if ", OPTIONAL" not in str(routine_arg.datatype): - raise CallMatchingArgumentsNotFound( - f"Argument '{routine_arg.name}' in subroutine" - f" '{routine.name}' does not match any in the call" - f" '{self.debug_string()}' and is not OPTIONAL." - ) - - return ret_arg_idx_list + return call_routine_matcher.get_callee_candidates() def get_callee( self, check_matching_arguments: bool = True, + check_strict_array_datatype: bool = True, + ignore_missing_modules: bool = False, + ignore_unresolved_symbol: bool = False, ): ''' Searches for the implementation(s) of the target routine for this Call @@ -765,33 +499,17 @@ def get_callee( in any containers in scope at the call site. ''' - routine_list = self.get_callees() - - err_info_list = [] - - # Search for the routine matching the right arguments - for routine in routine_list: - routine: Routine - - try: - arg_match_list = self._get_argument_routine_match(routine) - except CallMatchingArgumentsNotFound as err: - err_info_list.append(err.value) - continue - - return (routine, arg_match_list) - - # If we didn't find any routine, return some routine if no matching - # arguments have been found. - # This is handy for the transition phase until optional argument - # matching is supported. - if not check_matching_arguments: - # Also return a list of dummy argument indices - return list(range(len(self.arguments))) + from psyclone.psyir.tools.call_routine_matcher import ( + CallRoutineMatcher + ) - error_msg = "\n".join(err_info_list) + call_routine_matcher: CallRoutineMatcher = CallRoutineMatcher(self) + call_routine_matcher.set_option( + check_matching_arguments=check_matching_arguments, + check_strict_array_datatype=( + check_strict_array_datatype), + ignore_missing_modules=ignore_missing_modules, + ignore_unresolved_types=ignore_unresolved_symbol, + ) - raise CallMatchingArgumentsNotFound( - f"No matching routine found for '{self.debug_string()}':" - "\n" + error_msg - ) + return call_routine_matcher.get_callee() diff --git a/src/psyclone/psyir/symbols/containersymbol.py b/src/psyclone/psyir/symbols/containersymbol.py index 45830b9abb..43122c2f71 100644 --- a/src/psyclone/psyir/symbols/containersymbol.py +++ b/src/psyclone/psyir/symbols/containersymbol.py @@ -135,6 +135,10 @@ def find_container_psyir(self, local_node=None): the container. :type local_node: Optional[:py:class:`psyclone.psyir.nodes.Node`] + :param ignore_missing_modules: If 'True', no ModuleNotFound exception= + is raised in case in case the module wasn't found. + :type ignore_missing_modules: bool + :returns: referenced container. :rtype: :py:class:`psyclone.psyir.nodes.Container` diff --git a/src/psyclone/psyir/symbols/symbol_table.py b/src/psyclone/psyir/symbols/symbol_table.py index d3e09158e4..62a7ddc7de 100644 --- a/src/psyclone/psyir/symbols/symbol_table.py +++ b/src/psyclone/psyir/symbols/symbol_table.py @@ -585,8 +585,11 @@ def add(self, new_symbol, tag=None): self._symbols[key] = new_symbol - def check_for_clashes(self, other_table, symbols_to_skip=()): - ''' + def check_for_clashes( + self, other_table, symbols_to_skip=(), + check_unresolved_symbols: bool = True + ): + """ Checks the symbols in the supplied table against those in this table. If there is a name clash that cannot be resolved by renaming then a SymbolError is raised. Any symbols appearing @@ -598,13 +601,16 @@ def check_for_clashes(self, other_table, symbols_to_skip=()): the check. :type symbols_to_skip: Iterable[ :py:class:`psyclone.psyir.symbols.Symbol`] + :param check_unresolved_symbols: If 'True', also check unresolved + symbols + :type check_unresolved_symbols: bool :raises TypeError: if symbols_to_skip is supplied but is not an instance of Iterable. :raises SymbolError: if there would be an unresolvable name clash when importing symbols from `other_table` into this table. - ''' + """ # pylint: disable-next=import-outside-toplevel from psyclone.psyir.nodes import IntrinsicCall @@ -648,6 +654,10 @@ def check_for_clashes(self, other_table, symbols_to_skip=()): f"table imports it via '{other_sym.interface}'.") continue + if not check_unresolved_symbols: + # Skip if unresolved symbols shouldn't be checked + continue + if other_sym.is_unresolved and this_sym.is_unresolved: # Both Symbols are unresolved. if shared_wildcard_imports and not unique_wildcard_imports: @@ -822,7 +832,8 @@ def _handle_symbol_clash(self, old_sym, other_table): self.rename_symbol(self_sym, new_name) self.add(old_sym) - def merge(self, other_table, symbols_to_skip=()): + def merge(self, other_table, symbols_to_skip=(), + check_unresolved_symbols: bool = True): '''Merges all of the symbols found in `other_table` into this table. Symbol objects in *either* table may be renamed in the event of clashes. @@ -835,6 +846,9 @@ def merge(self, other_table, symbols_to_skip=()): the merge. :type symbols_to_skip: Iterable[ :py:class:`psyclone.psyir.symbols.Symbol`] + :param check_unresolved_symbols: If `True`, also check unresolved + symbols. + :type check_unresolved_symbols: bool :raises TypeError: if `other_table` is not a SymbolTable. :raises TypeError: if `symbols_to_skip` is not an Iterable. @@ -851,7 +865,9 @@ def merge(self, other_table, symbols_to_skip=()): try: self.check_for_clashes(other_table, - symbols_to_skip=symbols_to_skip) + symbols_to_skip=symbols_to_skip, + check_unresolved_symbols=( + check_unresolved_symbols)) except SymbolError as err: raise SymbolError( f"Cannot merge {other_table.view()} with {self.view()} due to " diff --git a/src/psyclone/psyir/tools/__init__.py b/src/psyclone/psyir/tools/__init__.py index ca77a226a4..32706db9c1 100644 --- a/src/psyclone/psyir/tools/__init__.py +++ b/src/psyclone/psyir/tools/__init__.py @@ -40,10 +40,14 @@ from psyclone.psyir.tools.dependency_tools import DTCode, DependencyTools from psyclone.psyir.tools.read_write_info import ReadWriteInfo from psyclone.psyir.tools.definition_use_chains import DefinitionUseChain +from psyclone.psyir.tools.call_routine_matcher import CallRoutineMatcher, CallMatchingArgumentsNotFoundError # For AutoAPI documentation generation. __all__ = ['CallTreeUtils', 'DTCode', 'DependencyTools', 'DefinitionUseChain', - 'ReadWriteInfo'] + 'ReadWriteInfo', + 'CallRoutineMatcher', + 'CallMatchingArgumentsNotFoundError', + ] diff --git a/src/psyclone/psyir/tools/call_routine_matcher.py b/src/psyclone/psyir/tools/call_routine_matcher.py new file mode 100644 index 0000000000..d344278eb9 --- /dev/null +++ b/src/psyclone/psyir/tools/call_routine_matcher.py @@ -0,0 +1,530 @@ +# ----------------------------------------------------------------------------- +# BSD 3-Clause License +# +# Copyright (c) 2024, Science and Technology Facilities Council and +# University Grenoble Alpes +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS +# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE +# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN +# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +# ----------------------------------------------------------------------------- +# This file is based on gathering various components related to +# calls and routines from across psyclone. Hence, there's no clear author. +# Authors of gathered files: R. W. Ford, A. R. Porter and +# S. Siso, STFC Daresbury Lab +# Creator/partial author of this file: M. Schreiber, University Grenoble Alpes +# ----------------------------------------------------------------------------- + +from typing import List, Union, Set +from psyclone.psyir.symbols.datatypes import ArrayType, UnresolvedType +from psyclone.errors import PSycloneError +from psyclone.psyir.nodes import Call, Routine +from psyclone.psyir.nodes.container import Container +from psyclone.psyir.symbols import ( + RoutineSymbol, + Symbol, + SymbolError, + UnsupportedFortranType, + DataSymbol, + SymbolTable, + ContainerSymbol, +) +from psyclone.configuration import Config + + +class CallMatchingArgumentsNotFoundError(PSycloneError): + """Exception to signal that matching arguments have not been found + for this routine + + """ + + def __init__(self, value): + PSycloneError.__init__(self, value) + self.value = "CallMatchingArgumentsNotFound: " + str(value) + + +class CallRoutineMatcher: + """Helper routines to help matching 'Call' and 'Routines'. + This includes, e.g., + - searching for matching 'Routines', + - argument matching + """ + + def __init__(self, call_node: Call = None, routine_node: Routine = None): + + # Psyir node of Call + self._call_node: Call = call_node + + # Psyir node of Routine + self._routine_node: Call = routine_node + + # List of indices relating each argument of call to one argument + # of routine. This is required to support optional arguments. + self._arg_match_list: List[int] = None + + # Also check argument types to match. + # If set to `False` and in case it doesn't find matching arguments, + # the very first implementation of the matching routine will be + # returned (even if the argument type check failed). The argument + # types and number of arguments might therefore mismatch! + self._option_check_matching_arguments: bool = True + + # Use strict array datatype checks for matching + self._option_check_strict_array_datatype: bool = True + + # If 'True', missing modules don't raise an Exception + self._option_ignore_missing_modules: bool = False + + # If 'True', unresolved types don't raise an Exception + self._option_ignore_unresolved_types: bool = False + + def set_call_node(self, call_node: Call): + self._call_node = call_node + + def set_routine_node(self, routine_node: Routine): + self._routine_node = routine_node + + def set_option(self, + check_matching_arguments: bool = None, + check_strict_array_datatype: bool = None, + ignore_missing_modules: bool = None, + ignore_unresolved_types: bool = None, + ): + + if check_matching_arguments is not None: + self._option_check_matching_arguments = check_matching_arguments + + if check_strict_array_datatype is not None: + self._option_check_strict_array_datatype = ( + check_strict_array_datatype) + + if ignore_missing_modules is not None: + self._option_ignore_missing_modules = ignore_missing_modules + + if ignore_unresolved_types is not None: + self._option_ignore_unresolved_types = ignore_unresolved_types + + def _check_argument_type_matches( + self, + call_arg: DataSymbol, + routine_arg: DataSymbol + ) -> bool: + """Return information whether argument types are matching. + This also supports 'optional' arguments by using + partial types. + + :param call_arg: Argument from the call + :type call_arg: DataSymbol + :param routine_arg: Argument from the routine + :type routine_arg: DataSymbol + :param check_strict_array_datatype: Check strictly for matching + array types. If `False`, only checks for ArrayType itself are done. + :type check_strict_array_datatype: bool + :returns: True if arguments match, False otherwise + :rtype: bool + :raises CallMatchingArgumentsNotFound: Raised if no matching arguments + were found. + """ + + if self._option_check_strict_array_datatype: + # No strict array checks have to be performed, just accept it + if isinstance(call_arg.datatype, ArrayType) and isinstance( + routine_arg.datatype, ArrayType + ): + return True + + if self._option_ignore_unresolved_types: + if isinstance(call_arg.datatype, UnresolvedType): + return True + + if isinstance(routine_arg.datatype, UnsupportedFortranType): + # This could be an 'optional' argument. + # This has at least a partial data type + if call_arg.datatype != routine_arg.datatype.partial_datatype: + raise CallMatchingArgumentsNotFoundError( + "Argument partial type mismatch of call " + f"argument '{call_arg}' and routine argument " + f"'{routine_arg}'" + ) + + return True + + if (isinstance(routine_arg.datatype, ArrayType) and + isinstance(call_arg.datatype, ArrayType)): + + # If these are two arrays, only make sure that the types + # match. + if (call_arg.datatype.datatype == routine_arg.datatype.datatype): + return True + + if call_arg.datatype != routine_arg.datatype: + raise CallMatchingArgumentsNotFoundError( + "Argument type mismatch of call argument " + f"'{call_arg}' with type '{call_arg.datatype}' " + "and routine argument " + f"'{routine_arg}' with type '{routine_arg.datatype}'." + ) + + return True + + def get_argument_routine_match_list( + self, + ) -> Union[None, List[int]]: + '''Return a list of integers giving for each argument of the call + the index of the argument in argument_list (typically of a routine) + + :return: None if no match was found, otherwise list of integers + referring to matching arguments. + :rtype: None|List[int] + :raises CallMatchingArgumentsNotFound: If there was some problem in + finding matching arguments. + ''' + + # Create a copy of the list of actual arguments to the routine. + # Once an argument has been successfully matched, set it to 'None' + routine_argument_list: List[DataSymbol] = ( + self._routine_node.symbol_table.argument_list[:] + ) + + if len(self._call_node.arguments) > len( + self._routine_node.symbol_table.argument_list): + call_str = self._call_node.debug_string().replace("\n", "") + raise CallMatchingArgumentsNotFoundError( + f"More arguments in call ('{call_str}')" + f" than callee (routine '{self._routine_node.name}')" + ) + + # Iterate over all arguments to the call + ret_arg_idx_list = [] + for call_arg_idx, call_arg in enumerate(self._call_node.arguments): + call_arg_idx: int + call_arg: DataSymbol + + # If the associated name is None, it's a positional argument + # => Just return the index if the types match + if self._call_node.argument_names[call_arg_idx] is None: + routine_arg = routine_argument_list[call_arg_idx] + routine_arg: DataSymbol + + self._check_argument_type_matches( + call_arg, routine_arg + ) + + ret_arg_idx_list.append(call_arg_idx) + routine_argument_list[call_arg_idx] = None + continue + + # + # Next, we handle all named arguments + # + arg_name = self._call_node.argument_names[call_arg_idx] + routine_arg_idx = None + + for routine_arg_idx, routine_arg in enumerate( + routine_argument_list + ): + routine_arg: DataSymbol + + # Check if argument was already processed + if routine_arg is None: + continue + + if arg_name.lower() == routine_arg.name.lower(): + self._check_argument_type_matches( + call_arg, + routine_arg + ) + ret_arg_idx_list.append(routine_arg_idx) + break + + else: + # It doesn't match => Raise exception + raise CallMatchingArgumentsNotFoundError( + f"Named argument '{arg_name}' not found." + ) + + routine_argument_list[routine_arg_idx] = None + + # + # Finally, we check if all left-over arguments are optional arguments + # + for routine_arg in routine_argument_list: + routine_arg: DataSymbol + + if routine_arg is None: + continue + + # TODO #759: Optional keyword is not yet supported in psyir. + # Hence, we use a simple string match. + if ", OPTIONAL" in str(routine_arg.datatype): + continue + + raise CallMatchingArgumentsNotFoundError( + f"Argument '{routine_arg.name}' in subroutine" + f" '{self._routine_node.name}' not handled." + ) + + return ret_arg_idx_list + + def get_callee_candidates(self) -> List[Routine]: + ''' + Searches for the implementation(s) of all potential target routines + for this Call without any arguments check. + + :returns: the Routine(s) that this call targets. + + :raises NotImplementedError: if the routine is not local and not found + in any containers in scope at the call site. + + ''' + + def _location_txt(node): + ''' + Utility to generate meaningful location text. + + :param node: a PSyIR node. + :type node: :py:class:`psyclone.psyir.nodes.Node` + + :returns: description of location of node. + :rtype: str + ''' + if isinstance(node, Container): + return f"Container '{node.name}'" + out_lines = node.debug_string().split("\n") + idx = -1 + while not out_lines[idx]: + idx -= 1 + last_line = out_lines[idx] + return f"code:\n'{out_lines[0]}\n...\n{last_line}'" + + rsym = self._call_node.routine.symbol + if rsym.is_unresolved: + + # Check for any "raw" Routines, i.e. ones that are not + # in a Container. Such Routines would exist in the PSyIR + # as a child of a FileContainer (if the PSyIR contains a + # FileContainer). Note, if the PSyIR does contain a + # FileContainer, it will be the root node of the PSyIR. + for routine in self._call_node.root.children: + if ( + isinstance(routine, Routine) + and routine.name.lower() == rsym.name.lower() + ): + return [routine] + + # Now check for any wildcard imports and see if they can + # be used to resolve the symbol. + wildcard_names = [] + containers_not_found = [] + current_table: SymbolTable = self._call_node.scope.symbol_table + while current_table: + # TODO: Obtaining all container symbols in this way + # breaks some tests. + # It would be better using the ModuleManager to resolve + # (and cache) all containers to look up for this. + # + # current_containersymbols = + # self._call_node._get_container_symbols_rec( + # current_table.containersymbols, + # ignore_missing_modules=ignore_missing_modules, + # ) + # for container_symbol in current_containersymbols: + for container_symbol in current_table.containersymbols: + container_symbol: ContainerSymbol + if container_symbol.wildcard_import: + wildcard_names.append(container_symbol.name) + + try: + container: Container = ( + container_symbol.find_container_psyir( + local_node=self._call_node, + ) + ) + except SymbolError: + container = None + if not container: + # Failed to find/process this Container. + containers_not_found.append(container_symbol.name) + continue + routines = [] + for name in container.resolve_routine(rsym.name): + # Allow private imports if an 'interface' + # was used. Here, we assume the name of the routine + # is different to the call. + allow_private = name != rsym.name + psyir = container.find_routine_psyir( + name, allow_private=allow_private + ) + + if psyir: + routines.append(psyir) + + if routines: + return routines + current_table = current_table.parent_symbol_table() + + if not wildcard_names: + wc_text = "there are no wildcard imports" + else: + if containers_not_found: + wc_text = ( + "attempted to resolve the wildcard imports from" + f" {wildcard_names}. However, failed to find the " + f"source for {containers_not_found}. The module search" + f" path is set to {Config.get().include_paths}" + ) + else: + wc_text = f"wildcard imports from {wildcard_names}" + raise NotImplementedError( + "Failed to find the source code of the unresolved routine " + f"'{rsym.name}' - looked at any routines in the same source " + f"file and {wc_text}. Searching for external routines " + "that are only resolved at link time is not supported." + ) + + root_node = self._call_node.ancestor(Container) + if not root_node: + root_node = self._call_node.root + container = root_node + can_be_private = True + + if rsym.is_import: + cursor = rsym + # A Routine imported from another Container must be public in that + # Container. + can_be_private = False + while cursor.is_import: + csym = cursor.interface.container_symbol + try: + container = csym.find_container_psyir( + local_node=self._call_node) + except SymbolError: + raise NotImplementedError( + f"RoutineSymbol '{rsym.name}' is imported from " + f"Container '{csym.name}' but the source defining " + "that container could not be found. The module search" + f" path is set to {Config.get().include_paths}" + ) + imported_sym = container.symbol_table.lookup(cursor.name) + if imported_sym.visibility != Symbol.Visibility.PUBLIC: + # The required Symbol must be shadowed with a PRIVATE + # Symbol in this Container. This means that the one we + # actually want is brought into scope via a wildcard + # import. + # TODO #924 - Use ModuleManager to search? + raise NotImplementedError( + f"RoutineSymbol '{rsym.name}' is imported from " + f"Container '{csym.name}' but that Container defines " + "a private Symbol of the same name. Searching for the" + " Container that defines a public Routine with that " + "name is not yet supported - TODO #924" + ) + if not isinstance(imported_sym, RoutineSymbol): + # We now know that this is a RoutineSymbol so specialise it + # in place. + imported_sym.specialise(RoutineSymbol) + cursor = imported_sym + rsym = cursor + root_node = container + + if isinstance(rsym.datatype, UnsupportedFortranType): + # TODO #924 - an UnsupportedFortranType here typically indicates + # that the target is actually an interface. + raise NotImplementedError( + f"RoutineSymbol '{rsym.name}' exists in " + f"{_location_txt(root_node)} but is of " + f"UnsupportedFortranType:\n{rsym.datatype.declaration}\n" + "Cannot get the PSyIR of such a routine." + ) + + if isinstance(container, Container): + routines = [] + for name in container.resolve_routine(rsym.name): + psyir = container.find_routine_psyir( + name, allow_private=can_be_private + ) + if psyir: + routines.append(psyir) + if routines: + return routines + + raise SymbolError( + f"Failed to find a Routine named '{rsym.name}' in " + f"{_location_txt(root_node)}. This is normally because the routine" + " is within a CodeBlock." + ) + + def get_callee(self) -> Set[Union[Routine, List[int]]]: + ''' + Searches for the implementation(s) of the target routine for this Call + including argument checks. + + :returns: A tuple of two elements. The first element is the routine + that this call targets. The second one a list of arguments + providing the information on matching argument indices. + + :raises CallMatchingArgumentsNotFoundError: if the routine is not local + and not found in any containers in scope at the call site or if + the arguments don't match. + ''' + + routine_list = self.get_callee_candidates() + assert len(routine_list) != 0 + + err_info_list = [] + + # Search for the routine matching the right arguments + for routine_node in routine_list: + routine_node: Routine + self._routine_node = routine_node + + try: + self._arg_match_list = self.get_argument_routine_match_list() + except CallMatchingArgumentsNotFoundError as err: + err_info_list.append(err.value) + continue + + return (self._routine_node, self._arg_match_list) + + # If we didn't find any routine, return some routine if no matching + # arguments have been found. + # This is handy for the transition phase until optional argument + # matching is supported. + if not self._option_check_matching_arguments: + # Also return a list of dummy argument indices + self._routine_node = routine_list[0] + self._arg_match_list = list(range(len(self._call_node.arguments))) + return (self._routine_node, self._arg_match_list) + + error_msg = "\n".join(err_info_list) + + s = str(self._call_node.debug_string()).replace("\n", "") + raise CallMatchingArgumentsNotFoundError( + "Found routines, but no routine with matching arguments found " + f"for '{s}':\n" + + error_msg + ) diff --git a/src/psyclone/psyir/transformations/inline_trans.py b/src/psyclone/psyir/transformations/inline_trans.py index 77084b9231..9732115eaf 100644 --- a/src/psyclone/psyir/transformations/inline_trans.py +++ b/src/psyclone/psyir/transformations/inline_trans.py @@ -37,7 +37,6 @@ This module contains the InlineTrans transformation. ''' -from psyclone.errors import LazyString from psyclone.psyGen import Transformation from psyclone.psyir.nodes import ( ArrayReference, ArrayOfStructuresReference, BinaryOperation, Call, @@ -45,14 +44,27 @@ Return, Literal, Statement, StructureMember, StructureReference) from psyclone.psyir.nodes.array_mixin import ArrayMixin from psyclone.psyir.symbols import ( - ArgumentInterface, ArrayType, DataSymbol, UnresolvedType, INTEGER_TYPE, - StaticInterface, SymbolError, UnknownInterface, - UnsupportedType, IntrinsicSymbol) + ArgumentInterface, + ArrayType, + DataSymbol, + INTEGER_TYPE, + StaticInterface, + SymbolError, + UnknownInterface, + UnsupportedType, + UnsupportedFortranType, + IntrinsicSymbol, +) from psyclone.psyir.transformations.reference2arrayrange_trans import ( Reference2ArrayRangeTrans) from psyclone.psyir.transformations.transformation_error import ( TransformationError) +from typing import Dict, List + +from psyclone.psyir.symbols import BOOLEAN_TYPE +from psyclone.psyir.symbols import ScalarType + _ONE = Literal("1", INTEGER_TYPE) @@ -122,47 +134,196 @@ class InlineTrans(Transformation): Some of these restrictions will be lifted by #924. ''' - def apply(self, node, options=None): - ''' + + def __init__(self): + # List of call-to-subroutine argument indices + self._ret_arg_match_list: List[int] = None + + # Call to routine + self._call_node: Call = None + + # Routine to be inlined for call + self._routine_node: Routine = None + + from psyclone.psyir.tools import CallRoutineMatcher + + self._call_routine_matcher: CallRoutineMatcher = CallRoutineMatcher() + + # If 'True', don't inline if a code block is used within the + # Routine. + self._option_check_codeblocks: bool = True + + self._option_check_diff_container_clashes: bool = True + self._option_check_diff_container_clashes_unres_types: bool = True + self._option_check_resolve_imports: bool = True + self._option_check_static_interface: bool = True + self._option_check_array_type: bool = True + self._option_check_unsupported_type: bool = True + self._option_check_unresolved_symbols: bool = True + + def set_option( + self, + ignore_missing_modules: bool = None, + check_argument_strict_array_datatype: bool = None, + check_argument_matching: bool = None, + check_argument_ignore_unresolved_types: bool = None, + + check_inline_codeblocks: bool = None, + check_diff_container_clashes: bool = None, + check_diff_container_clashes_unres_types: bool = None, + check_resolve_imports: bool = None, + check_static_interface: bool = None, + check_array_type: bool = None, + check_unsupported_type: bool = None, + check_unresolved_symbols: bool = None, + ): + """Set special options + + :param ignore_missing_modules: If `True`, raise ModuleNotFound if + module is not available, defaults to None + :param check_argument_strict_array_datatype: + If `True`, make strict checks for matching arguments of + array data types. + If disabled, it's sufficient that both arguments are of ArrayType. + Then, no further checks are performed, defaults to None + :param check_argument_matching: If `True`, check for all arguments + to match. If `False`, if no matching argument was found, take + 1st one in list. Defaults to None + :param check_inline_codeblocks: If `True`, raise Exception + if encountering code blocks, defaults to None + :param check_diff_container_clashes: + If `True` and different symbols share a name but are imported + from different containers, raise Exception. + :param check_diff_container_clashes_unres_types: If `True`, + raise Exception if unresolved types are clashing, defaults to None + :param check_resolve_imports: If `True`, also resolve imports, + defaults to None + :param check_static_interface: + Check that there are no static variables in the routine + (because we don't know whether the routine is called from + other places). Defaults to None + :param check_array_type: If `True` and argument is an array, + check that inlining is working for this array type, + defaults to None + :param check_unsupported_type: If `True`, + also perform checks (fail inlining) on arguments of + unsupported type, defaults to None + :param check_argument_unresolved_symbols: If `True`, + stop if encountering an unresolved symbol, defaults to None + """ + + self._call_routine_matcher.set_option( + ignore_missing_modules=ignore_missing_modules, + check_strict_array_datatype=check_argument_strict_array_datatype, + check_matching_arguments=check_argument_matching, + ignore_unresolved_types=check_argument_ignore_unresolved_types + ) + + if check_inline_codeblocks is not None: + self._option_check_codeblocks = check_inline_codeblocks + + if check_diff_container_clashes is not None: + self._option_check_diff_container_clashes = ( + check_diff_container_clashes) + + if check_diff_container_clashes_unres_types is not None: + self._option_check_diff_container_clashes_unres_types = ( + check_diff_container_clashes_unres_types + ) + + if check_resolve_imports is not None: + self._option_check_resolve_imports = check_resolve_imports + + if check_static_interface is not None: + self._option_check_static_interface = check_static_interface + + if check_array_type is not None: + self._option_check_array_type = check_array_type + + if check_unsupported_type is not None: + self._option_check_unsupported_type = ( + check_unsupported_type + ) + + if check_unresolved_symbols is not None: + self._option_check_unresolved_symbols = ( + check_unresolved_symbols + ) + + def apply( + self, call_node: Call, routine_node: Routine = None, options=None + ): + """ Takes the body of the routine that is the target of the supplied call and replaces the call with it. - :param node: target PSyIR node. - :type node: :py:class:`psyclone.psyir.nodes.Routine` + :param call_node: target PSyIR node. + :type call_node: :py:class:`psyclone.psyir.nodes.Call` + :param routine: PSyIR subroutine to be inlined. + Default: Automatically determine subroutine (search) + :type routine: :py:class:`psyclone.psyir.nodes.Routine` :param options: a dictionary with options for transformations. :type options: Optional[Dict[str, Any]] :param bool options["force"]: whether or not to permit the inlining of Routines containing CodeBlocks. Default is False. - ''' - self.validate(node, options) + """ + + # Validate that the inlining can also be accomplish. + # This routine will also update + # self.node_routine and self._ret_arg_match_list + # with the routine to be inlined and the relation between the + # arguments and to which routine arguments they are matched to. + self.validate(call_node, routine_node=routine_node, options=options) + # The table associated with the scoping region holding the Call. - table = node.scope.symbol_table - # Find the routine to be inlined. - orig_routine = node.get_callees()[0] + table = call_node.scope.symbol_table - if not orig_routine.children or isinstance(orig_routine.children[0], - Return): + if not self._routine_node.children or isinstance( + self._routine_node.children[0], Return + ): # Called routine is empty so just remove the call. - node.detach() + call_node.detach() return # Ensure we don't modify the original Routine by working with a # copy of it. - routine = orig_routine.copy() - routine_table = routine.symbol_table + self._routine_node = self._routine_node.copy() + routine_table = self._routine_node.symbol_table + + # Next, we remove all optional arguments which are not used. + # Step 1) + # - Build lookup dictionary for all optional arguments: + + # - For all `PRESENT(...)`: + # - Lookup variable in dictionary + # - Replace with `True` or `False`, depending on whether + # it's provided or not. + self._optional_arg_resolve_present_intrinsics() + + # Step 2) + # - For all If-Statements, handle constant conditions: + # - `True`: Replace If-Block with If-Body + # - `False`: Replace If-Block with Else-Body. If it doesn't exist + # just delete the if statement. + self._optional_arg_eliminate_ifblock_if_const_condition() # Construct lists of the nodes that will be inserted and all of the # References that they contain. new_stmts = [] refs = [] - for child in routine.children: + for child in self._routine_node.children: + child: Node new_stmts.append(child.copy()) refs.extend(new_stmts[-1].walk(Reference)) # Shallow copy the symbols from the routine into the table at the # call site. - table.merge(routine_table, - symbols_to_skip=routine_table.argument_list[:]) + table.merge( + routine_table, + symbols_to_skip=routine_table.argument_list[:], + check_unresolved_symbols=( + self._option_check_unresolved_symbols), + ) # When constructing new references to replace references to formal # args, we need to know whether any of the actual arguments are array @@ -171,7 +332,7 @@ def apply(self, node, options=None): # as a Reference. ref2arraytrans = Reference2ArrayRangeTrans() - for child in node.arguments: + for child in call_node.arguments: try: # TODO #1858, this won't yet work for arrays inside structures. ref2arraytrans.apply(child) @@ -182,12 +343,12 @@ def apply(self, node, options=None): # actual arguments. formal_args = routine_table.argument_list for ref in refs[:]: - self._replace_formal_arg(ref, node, formal_args) + self._replace_formal_arg(ref, call_node, formal_args) # Store the Routine level symbol table and node's current scope # so we can merge symbol tables later if required. - ancestor_table = node.ancestor(Routine).scope.symbol_table - scope = node.scope + ancestor_table = call_node.ancestor(Routine).scope.symbol_table + scope = call_node.scope # Copy the nodes from the Routine into the call site. # TODO #924 - while doing this we should ensure that any References @@ -198,9 +359,9 @@ def apply(self, node, options=None): # remove it from the list. del new_stmts[-1] - if routine.return_symbol: + if self._routine_node.return_symbol: # This is a function - assignment = node.ancestor(Statement) + assignment = call_node.ancestor(Statement) parent = assignment.parent idx = assignment.position-1 for child in new_stmts: @@ -209,14 +370,17 @@ def apply(self, node, options=None): table = parent.scope.symbol_table # Avoid a potential name clash with the original function table.rename_symbol( - routine.return_symbol, table.next_available_name( - f"inlined_{routine.return_symbol.name}")) - node.replace_with(Reference(routine.return_symbol)) + self._routine_node.return_symbol, + table.next_available_name( + f"inlined_{self._routine_node.return_symbol.name}" + ), + ) + call_node.replace_with(Reference(self._routine_node.return_symbol)) else: # This is a call - parent = node.parent - idx = node.position - node.replace_with(new_stmts[0]) + parent = call_node.parent + idx = call_node.position + call_node.replace_with(new_stmts[0]) for child in new_stmts[1:]: idx += 1 parent.addchild(child, idx) @@ -226,11 +390,130 @@ def apply(self, node, options=None): # the ancestor Routine. This avoids issues like #2424 when # applying ParallelLoopTrans to loops containing inlined calls. if ancestor_table is not scope.symbol_table: - ancestor_table.merge(scope.symbol_table) + ancestor_table.merge( + scope.symbol_table, + check_unresolved_symbols=( + self._option_check_unresolved_symbols)) replacement = type(scope.symbol_table)() scope.symbol_table.detach() replacement.attach(scope) + def _optional_arg_resolve_present_intrinsics(self): + """Replace PRESENT(some_argument) intrinsics in routine with constant + booleans depending on whether `some_argument` has been provided + (`True`) or not (`False`). + + :rtype: None + """ + # We first build a lookup table of all optional arguments + # to see whether it's present or not. + optional_sym_present_dict: Dict[str, bool] = dict() + for optional_arg_idx, datasymbol in enumerate( + self._routine_node.symbol_table.datasymbols + ): + if not isinstance(datasymbol.datatype, UnsupportedFortranType): + continue + + if ", OPTIONAL" not in str(datasymbol.datatype): + continue + + sym_name = datasymbol.name.lower() + + if optional_arg_idx not in self._ret_arg_match_list: + optional_sym_present_dict[sym_name] = False + else: + optional_sym_present_dict[sym_name] = True + + # Check if we have any optional arguments at all and if not, return + if len(optional_sym_present_dict) == 0: + return + + # Find all "PRESENT()" calls + for intrinsic_call in self._routine_node.walk(IntrinsicCall): + intrinsic_call: IntrinsicCall + if intrinsic_call.routine.name.lower() == "present": + + # The argument is in the 2nd child + present_arg: Reference = intrinsic_call.children[1] + present_arg_name = present_arg.name.lower() + + assert present_arg_name in optional_sym_present_dict + + if optional_sym_present_dict[present_arg_name]: + # The argument is present. + intrinsic_call.replace_with(Literal("true", BOOLEAN_TYPE)) + else: + intrinsic_call.replace_with(Literal("false", BOOLEAN_TYPE)) + + def _optional_arg_eliminate_ifblock_if_const_condition(self): + """Eliminate if-block if conditions are constant booleans. + + TODO: This also requires support of conditions containing logical + expressions such as `(.true. .or. .false.)` + TODO: This could also become a Psyclone transformation. + + :rtype: None + """ + + def if_else_replace(main_schedule, if_block, if_body_schedule): + """Little helper routine to eliminate one branch of an IfBlock + + :param main_schedule: Schedule where if-branch is used + :type main_schedule: Schedule + :param if_block: If-else block itself + :type if_block: IfBlock + :param if_body_schedule: The body of the if or else block + :type if_body_schedule: Schedule + """ + + from psyclone.psyir.nodes import Schedule + + assert isinstance(main_schedule, Schedule) + assert isinstance(if_body_schedule, Schedule) + + # Obtain index in main schedule + idx = main_schedule.children.index(if_block) + + # Detach it + if_block.detach() + + # Insert childreen of if-body schedule + for child in if_body_schedule.children: + main_schedule.addchild(child.copy(), idx) + idx += 1 + + from psyclone.psyir.nodes import IfBlock + + for if_block in self._routine_node.walk(IfBlock): + if_block: IfBlock + + condition = if_block.condition + + # Make sure we only handle a BooleanLiteral as a condition + # TODO #2802 + if not isinstance(condition, Literal): + continue + + # Check that it's a boolean Literal + assert ( + condition.datatype.intrinsic + is ScalarType.Intrinsic.BOOLEAN + ), "Found non-boolean expression in conditional of if branch" + + if condition.value == "true": + # Only keep if_block + if_else_replace(if_block.parent, if_block, if_block.if_body) + + else: + # If there's an else block, replace if-condition with + # else-block + if not if_block.else_body: + if_block.detach() + continue + + # Only keep else block + if_else_replace(if_block.parent, if_block, if_block.else_body) + def _replace_formal_arg(self, ref, call_node, formal_args): ''' Recursively combines any References to formal arguments in the supplied @@ -260,8 +543,26 @@ def _replace_formal_arg(self, ref, call_node, formal_args): # The supplied reference is not to a formal argument. return ref + # Lookup index in routine argument + routine_arg_idx = formal_args.index(ref.symbol) + + # Lookup index of actual argument + # If this is an optional argument, but not used, this index lookup + # shouldn't fail + try: + actual_arg_idx = self._ret_arg_match_list.index(routine_arg_idx) + except ValueError as err: + arg_list = self._routine_node.symbol_table.argument_list + arg_name = arg_list[routine_arg_idx].name + raise TransformationError( + f"Subroutine argument '{arg_name}' is not provided by call," + f" but used in the subroutine." + f" If this is correct code, this is likely due to" + f" some non-eliminated if-branches using `PRESENT(...)` as" + f" conditional (TODO #2802).") from err + # Lookup the actual argument that corresponds to this formal argument. - actual_arg = call_node.arguments[formal_args.index(ref.symbol)] + actual_arg = call_node.arguments[actual_arg_idx] # If the local reference is a simple Reference then we can just # replace it with a copy of the actual argument, e.g. @@ -579,140 +880,257 @@ def _replace_formal_struc_arg(self, actual_arg, ref, call_node, # Just an array reference. return ArrayReference.create(actual_arg.symbol, members[0][1]) - def validate(self, node, options=None): - ''' - Checks that the supplied node is a valid target for inlining. - - :param node: target PSyIR node. - :type node: subclass of :py:class:`psyclone.psyir.nodes.Call` - :param options: a dictionary with options for transformations. - :type options: Optional[Dict[str, Any]] - :param bool options["force"]: whether or not to ignore any CodeBlocks - in the candidate routine. Default is False. - - :raises TransformationError: if the supplied node is not a Call or is - an IntrinsicCall or call to a PSyclone-generated routine. - :raises TransformationError: if the routine has a return value. - :raises TransformationError: if the routine body contains a Return - that is not the first or last statement. - :raises TransformationError: if the routine body contains a CodeBlock - and the 'force' option is not True. - :raises TransformationError: if the called routine has a named - argument. - :raises TransformationError: if any of the variables declared within - the called routine are of UnknownInterface. - :raises TransformationError: if any of the variables declared within - the called routine have a StaticInterface. - :raises TransformationError: if any of the subroutine arguments is of - UnsupportedType. - :raises TransformationError: if a symbol of a given name is imported - from different containers at the call site and within the routine. - :raises TransformationError: if the routine accesses an un-resolved - symbol. - :raises TransformationError: if the number of arguments in the call - does not match the number of formal arguments of the routine. - :raises TransformationError: if a symbol declared in the parent - container is accessed in the target routine. - :raises TransformationError: if the shape of an array formal argument - does not match that of the corresponding actual argument. - - ''' - super().validate(node, options=options) - - options = {} if options is None else options - forced = options.get("force", False) - - # The node should be a Call. - if not isinstance(node, Call): - raise TransformationError( - f"The target of the InlineTrans transformation " - f"should be a Call but found '{type(node).__name__}'.") - - if isinstance(node, IntrinsicCall): - raise TransformationError( - f"Cannot inline an IntrinsicCall ('{node.routine.name}')") - name = node.routine.name - - # Check that we can find the source of the routine being inlined. - # TODO #924 allow for multiple routines (interfaces). - try: - routine = node.get_callees()[0] - except (NotImplementedError, FileNotFoundError, SymbolError) as err: + def _validate_inline_of_call_and_routine_argument_pairs( + self, + call_arg: DataSymbol, + routine_arg: DataSymbol + ) -> bool: + """This function performs tests to see whether the + inlining can cope with it. + + :param call_arg: The argument of a call + :type call_arg: DataSymbol + :param routine_arg: The argument of a routine + :type routine_arg: DataSymbol + + :raises TransformationError: Raised if transformation can't be done + + :return: 'True' if checks are successful + :rtype: bool + """ + from psyclone.psyir.transformations.transformation_error import ( + TransformationError, + ) + from psyclone.errors import LazyString + from psyclone.psyir.nodes import Literal, Range + from psyclone.psyir.symbols import ( + UnresolvedType, + UnsupportedType, + INTEGER_TYPE, + ) + + _ONE = Literal("1", INTEGER_TYPE) + + # If the formal argument is an array with non-default bounds then + # we also need to know the bounds of that array at the call site. + if not isinstance(routine_arg.datatype, ArrayType): + # Formal argument is not an array so we don't need to do any + # further checks. + return True + + if not isinstance(call_arg, (Reference, Literal)): + # TODO #1799 this really needs the `datatype` method to be + # extended to support all nodes. For now we have to abort + # if we encounter an argument that is not a scalar (according + # to the corresponding formal argument) but is not a + # Reference or a Literal as we don't know whether the result + # of any general expression is or is not an array. + # pylint: disable=cell-var-from-loop raise TransformationError( - f"Cannot inline routine '{name}' because its source cannot be " - f"found: {err}") from err + LazyString( + lambda: ( + f"The call '{self._call_node.debug_string()}' " + "cannot be inlined because actual argument " + f"'{call_arg.debug_string()}' corresponds to a " + "formal argument with array type but is not a " + "Reference or a Literal." + ) + ) + ) + + # We have an array argument. We are only able to check that the + # argument is not re-shaped in the called routine if we have full + # type information on the actual argument. + # TODO #924. It would be useful if the `datatype` property was + # a method that took an optional 'resolve' argument to indicate + # that it should attempt to resolve any UnresolvedTypes. + if self._option_check_array_type: + if isinstance( + call_arg.datatype, (UnresolvedType, UnsupportedType) + ) or ( + isinstance(call_arg.datatype, ArrayType) + and isinstance( + call_arg.datatype.intrinsic, + (UnresolvedType, UnsupportedType), + ) + ): + raise TransformationError( + f"Routine '{self._routine_node.name}' cannot be " + "inlined because the type of the actual argument " + f"'{call_arg.symbol.name}' corresponding to an array" + f" formal argument ('{routine_arg.name}') is unknown." + ) - if not routine.children or isinstance(routine.children[0], Return): + formal_rank = 0 + actual_rank = 0 + if isinstance(routine_arg.datatype, ArrayType): + formal_rank = len(routine_arg.datatype.shape) + if isinstance(call_arg.datatype, ArrayType): + actual_rank = len(call_arg.datatype.shape) + if formal_rank != actual_rank: + # It's OK to use the loop variable in the lambda definition + # because if we get to this point then we're going to quit + # the loop. + # pylint: disable=cell-var-from-loop + raise TransformationError( + LazyString( + lambda: ( + "Cannot inline routine" + f" '{self._routine_node.name}' because it" + " reshapes an argument: actual argument" + f" '{call_arg.debug_string()}' has rank" + f" {actual_rank} but the corresponding formal" + f" argument, '{routine_arg.name}', has rank" + f" {formal_rank}" + ) + ) + ) + if actual_rank: + ranges = call_arg.walk(Range) + for rge in ranges: + ancestor_ref = rge.ancestor(Reference) + if ancestor_ref is not call_arg: + # Have a range in an indirect access. + # pylint: disable=cell-var-from-loop + raise TransformationError( + LazyString( + lambda: ( + "Cannot inline routine" + f" '{self._routine_node.name}' because" + " argument" + f" '{call_arg.debug_string()}' has" + " an array range in an indirect" + " access #(TODO 924)." + ) + ) + ) + if rge.step != _ONE: + # TODO #1646. We could resolve this problem by + # making a new array and copying the necessary + # values into it. + # pylint: disable=cell-var-from-loop + raise TransformationError( + LazyString( + lambda: ( + "Cannot inline routine" + f" '{self._routine_node.name}' because" + " one of its arguments is an array" + " slice with a non-unit stride:" + f" '{call_arg.debug_string()}' (TODO" + " #1646)" + ) + ) + ) + + def _validate_inline_of_call_and_routine( + self, + call_node: Call, + routine_node: Routine, + arg_index_list: List[int] + ): + """Performs various checks that the inlining is supported for the + combination of the call's and routine's arguments. + + :param call_node: Call to be replaced by the inlined Routine + :type call_node: Call + :param routine_node: Routine to be inlined + :type routine_node: Routine + :param arg_index_list: Argument index list to match the arguments of + the call to those of the routine in case of optional arguments. + :type arg_index_list: List[int] + :raises TransformationError: Arguments are not in a form to be inlined + + """ + + name = call_node.routine.name + + if not routine_node.children or isinstance( + routine_node.children[0], Return + ): # An empty routine is fine. return - return_stmts = routine.walk(Return) + return_stmts = routine_node.walk(Return) if return_stmts: - if len(return_stmts) > 1 or not isinstance(routine.children[-1], - Return): + if len(return_stmts) > 1 or not isinstance( + routine_node.children[-1], Return + ): # Either there is more than one Return statement or there is # just one but it isn't the last statement of the Routine. raise TransformationError( f"Routine '{name}' contains one or more " f"Return statements and therefore cannot be inlined.") - if routine.walk(CodeBlock) and not forced: - # N.B. we permit the user to specify the "force" option to allow - # CodeBlocks to be included. - raise TransformationError( - f"Routine '{name}' contains one or more CodeBlocks and " - "therefore cannot be inlined. (If you are confident that " - "the code may safely be inlined despite this then use " - "`options={'force': True}` to override.)") - - # Support for routines with named arguments is not yet implemented. - # TODO #924. - for arg in node.argument_names: - if arg: + if self._option_check_codeblocks: + if routine_node.walk(CodeBlock): + # N.B. we permit the user to specify the "force" option to + # allow CodeBlocks to be included. raise TransformationError( - f"Routine '{routine.name}' cannot be inlined because it " - f"has a named argument '{arg}' (TODO #924).") + f"Routine '{name}' contains one or more CodeBlocks and " + "therefore cannot be inlined. (If you are confident that " + "the code may safely be inlined despite this then use " + "`check_codeblocks=False` to override.)" + ) - table = node.scope.symbol_table - routine_table = routine.symbol_table + table = call_node.scope.symbol_table + routine_table = routine_node.symbol_table for sym in routine_table.datasymbols: # We don't inline symbols that have an UnsupportedType and are # arguments since we don't know if a simple assignment if # enough (e.g. pointers) - if isinstance(sym.interface, ArgumentInterface): - if isinstance(sym.datatype, UnsupportedType): + if self._option_check_unsupported_type: + if isinstance(sym.interface, ArgumentInterface): + if isinstance(sym.datatype, UnsupportedType): + if ", OPTIONAL" not in sym.datatype.declaration: + raise TransformationError( + f"Routine '{routine_node.name}' cannot be" + " inlined because it contains a Symbol" + f" '{sym.name}' which is an Argument of" + " UnsupportedType:" + f" '{sym.datatype.declaration}'." + ) + # We don't inline symbols that have an UnknownInterface, as we + # don't know how they are brought into this scope. + if isinstance(sym.interface, UnknownInterface): raise TransformationError( - f"Routine '{routine.name}' cannot be inlined because " - f"it contains a Symbol '{sym.name}' which is an " - f"Argument of UnsupportedType: " - f"'{sym.datatype.declaration}'") - # We don't inline symbols that have an UnknownInterface, as we - # don't know how they are brought into this scope. - if isinstance(sym.interface, UnknownInterface): - raise TransformationError( - f"Routine '{routine.name}' cannot be inlined because it " - f"contains a Symbol '{sym.name}' with an UnknownInterface:" - f" '{sym.datatype.declaration}'") - # Check that there are no static variables in the routine (because - # we don't know whether the routine is called from other places). - if (isinstance(sym.interface, StaticInterface) and - not sym.is_constant): + f"Routine '{routine_node.name}' cannot be " + "inlined because it contains a Symbol " + f"'{sym.name}' with an UnknownInterface: " + f"'{sym.datatype.declaration}'." + ) + + if self._option_check_static_interface: + # Check that there are no static variables in the routine + # (because we don't know whether the routine is called from + # other places). + if ( + isinstance(sym.interface, StaticInterface) + and not sym.is_constant + ): + raise TransformationError( + f"Routine '{routine_node.name}' cannot be " + "inlined because it has a static (Fortran SAVE) " + f"interface for Symbol '{sym.name}'." + ) + + if self._option_check_diff_container_clashes: + # We can't handle a clash between (apparently) different symbols + # that share a name but are imported from different containers. + try: + table.check_for_clashes( + routine_table, + symbols_to_skip=routine_table.argument_list[:], + check_unresolved_symbols=( + self._option_check_diff_container_clashes_unres_types + ), + ) + except SymbolError as err: raise TransformationError( - f"Routine '{routine.name}' cannot be inlined because it " - f"has a static (Fortran SAVE) interface for Symbol " - f"'{sym.name}'.") - - # We can't handle a clash between (apparently) different symbols that - # share a name but are imported from different containers. - try: - table.check_for_clashes( - routine_table, - symbols_to_skip=routine_table.argument_list[:]) - except SymbolError as err: - raise TransformationError( - f"One or more symbols from routine '{routine.name}' cannot be " - f"added to the table at the call site.") from err + "One or more symbols from routine " + f"'{routine_node.name}' cannot be added to the " + "table at the call site." + ) from err # Check for unresolved symbols or for any accessed from the Container # containing the target routine. @@ -722,22 +1140,25 @@ def validate(self, node, options=None): # that are used to define the precision of other Symbols in the same # table. If a precision symbol is only used within Statements then we # don't currently capture the fact that it is a precision symbol. - ref_or_lits = routine.walk((Reference, Literal)) + ref_or_lits = routine_node.walk((Reference, Literal)) # Check for symbols in any initial-value expressions # (including Fortran parameters) or array dimensions. for sym in routine_table.datasymbols: if sym.initial_value: ref_or_lits.extend( - sym.initial_value.walk((Reference, Literal))) + sym.initial_value.walk((Reference, Literal)) + ) if isinstance(sym.datatype, ArrayType): for dim in sym.shape: if isinstance(dim, ArrayType.ArrayBounds): if isinstance(dim.lower, Node): - ref_or_lits.extend(dim.lower.walk(Reference, - Literal)) + ref_or_lits.extend( + dim.lower.walk(Reference, Literal) + ) if isinstance(dim.upper, Node): - ref_or_lits.extend(dim.upper.walk(Reference, - Literal)) + ref_or_lits.extend( + dim.upper.walk(Reference, Literal) + ) # Keep a reference to each Symbol that we check so that we can avoid # repeatedly checking the same Symbol. _symbol_cache = set() @@ -754,115 +1175,154 @@ def validate(self, node, options=None): _symbol_cache.add(sym) if isinstance(sym, IntrinsicSymbol): continue - # We haven't seen this Symbol before. - if sym.is_unresolved: - try: - routine_table.resolve_imports(symbol_target=sym) - except KeyError: - # The symbol is not (directly) imported into the symbol - # table local to the routine. - # pylint: disable=raise-missing-from - raise TransformationError( - f"Routine '{routine.name}' cannot be inlined " - f"because it accesses variable '{sym.name}' and this " - f"cannot be found in any of the containers directly " - f"imported into its symbol table.") - else: - if sym.name not in routine_table: - raise TransformationError( - f"Routine '{routine.name}' cannot be inlined because " - f"it accesses variable '{sym.name}' from its " - f"parent container.") - - # Check that the shapes of any formal array arguments are the same as - # those at the call site. - if len(routine_table.argument_list) != len(node.arguments): - raise TransformationError(LazyString( - lambda: f"Cannot inline '{node.debug_string().strip()}' " - f"because the number of arguments supplied to the call " - f"({len(node.arguments)}) does not match the number of " - f"arguments the routine is declared to have " - f"({len(routine_table.argument_list)}).")) - - for formal_arg, actual_arg in zip(routine_table.argument_list, - node.arguments): - # If the formal argument is an array with non-default bounds then - # we also need to know the bounds of that array at the call site. - if not isinstance(formal_arg.datatype, ArrayType): - # Formal argument is not an array so we don't need to do any - # further checks. - continue - if not isinstance(actual_arg, (Reference, Literal)): - # TODO #1799 this really needs the `datatype` method to be - # extended to support all nodes. For now we have to abort - # if we encounter an argument that is not a scalar (according - # to the corresponding formal argument) but is not a - # Reference or a Literal as we don't know whether the result - # of any general expression is or is not an array. - # pylint: disable=cell-var-from-loop - raise TransformationError(LazyString( - lambda: f"The call '{node.debug_string()}' cannot be " - f"inlined because actual argument " - f"'{actual_arg.debug_string()}' corresponds to a " - f"formal argument with array type but is not a " - f"Reference or a Literal.")) - - # We have an array argument. We are only able to check that the - # argument is not re-shaped in the called routine if we have full - # type information on the actual argument. - # TODO #924. It would be useful if the `datatype` property was - # a method that took an optional 'resolve' argument to indicate - # that it should attempt to resolve any UnresolvedTypes. - if (isinstance(actual_arg.datatype, - (UnresolvedType, UnsupportedType)) or - (isinstance(actual_arg.datatype, ArrayType) and - isinstance(actual_arg.datatype.intrinsic, - (UnresolvedType, UnsupportedType)))): + if self._option_check_resolve_imports: + # We haven't seen this Symbol before. + if sym.is_unresolved: + try: + routine_table.resolve_imports(symbol_target=sym) + except KeyError: + # The symbol is not (directly) imported into the symbol + # table local to the routine. + # pylint: disable=raise-missing-from + raise TransformationError( + f"Routine '{routine_node.name}' cannot be " + "inlined because it accesses variable " + f"'{sym.name}' and this cannot be found in any " + "of the containers directly imported into its " + "symbol table." + ) + else: + if sym.name not in routine_table: + raise TransformationError( + f"Routine '{routine_node.name}' cannot be " + "inlined because it accesses variable " + f"'{sym.name}' from its parent container." + ) + + # Create a list of routine arguments that is actually used + routine_arg_list = [ + routine_table.argument_list[i] for i in arg_index_list + ] + + for routine_arg, call_arg in zip( + routine_arg_list, call_node.arguments + ): + self._validate_inline_of_call_and_routine_argument_pairs( + call_arg, + routine_arg + ) + + def validate( + self, + call_node: Call, + routine_node: Routine = None, + options: Dict[str, str] = None, + ): + """ + Checks that the supplied node is a valid target for inlining. + + :param call_node: target PSyIR node. + :type call_node: subclass of :py:class:`psyclone.psyir.nodes.Call` + :param routine_node: Routine to inline. + Default is to search for it. + :type routine_node: subclass of :py:class:`Routine` + :param options: a dictionary with options for transformations. + :type options: Optional[Dict[str, Any]] + :param bool options["force"]: whether or not to ignore any CodeBlocks + in the candidate routine. Default is False. + + :raises TransformationError: if the supplied node is not a Call or is + an IntrinsicCall or call to a PSyclone-generated routine. + :raises TransformationError: if the routine has a return value. + :raises TransformationError: if the routine body contains a Return + that is not the first or last statement. + :raises TransformationError: if the routine body contains a CodeBlock + and the 'force' option is not True. + :raises TransformationError: if the called routine has a named + argument. + :raises TransformationError: if any of the variables declared within + the called routine are of UnknownInterface. + :raises TransformationError: if any of the variables declared within + the called routine have a StaticInterface. + :raises TransformationError: if any of the subroutine arguments is of + UnsupportedType. + :raises TransformationError: if a symbol of a given name is imported + from different containers at the call site and within the routine. + :raises TransformationError: if the routine accesses an un-resolved + symbol. + :raises TransformationError: if the number of arguments in the call + does not match the number of formal arguments of the routine. + :raises TransformationError: if a symbol declared in the parent + container is accessed in the target routine. + :raises TransformationError: if the shape of an array formal argument + does not match that of the corresponding actual argument. + + """ + super().validate(call_node, options=options) + + self._call_node = call_node + self._routine_node = routine_node + + # The node should be a Call. + if not isinstance(self._call_node, Call): + raise TransformationError( + "The target of the InlineTrans transformation should" + f" be a Call but found '{type(self._call_node).__name__}'." + ) + + call_name = self._call_node.routine.name + if isinstance(self._call_node, IntrinsicCall): + raise TransformationError( + f"Cannot inline an IntrinsicCall ('{call_name}')" + ) + + # List of indices relating the call's arguments to the subroutine + # arguments. This can be different due to + # - optional arguments + # - named arguments + + from psyclone.psyir.tools import CallMatchingArgumentsNotFoundError + + self._call_routine_matcher.set_call_node(self._call_node) + + if self._routine_node is None: + # Check that we can find the source of the routine being inlined. + # TODO #924 allow for multiple routines (interfaces). + try: + (self._routine_node, self._ret_arg_match_list) = \ + self._call_routine_matcher.get_callee() + except ( + CallMatchingArgumentsNotFoundError, + NotImplementedError, + FileNotFoundError, + SymbolError, + TransformationError, + ) as err: raise TransformationError( - f"Routine '{routine.name}' cannot be inlined because " - f"the type of the actual argument " - f"'{actual_arg.symbol.name}' corresponding to an array" - f" formal argument ('{formal_arg.name}') is unknown.") + f"Cannot inline routine '{call_name}' because its source" + f" cannot be found:\n{str(err)}" + ) from err - formal_rank = 0 - actual_rank = 0 - if isinstance(formal_arg.datatype, ArrayType): - formal_rank = len(formal_arg.datatype.shape) - if isinstance(actual_arg.datatype, ArrayType): - actual_rank = len(actual_arg.datatype.shape) - if formal_rank != actual_rank: - # It's OK to use the loop variable in the lambda definition - # because if we get to this point then we're going to quit - # the loop. - # pylint: disable=cell-var-from-loop - raise TransformationError(LazyString( - lambda: f"Cannot inline routine '{routine.name}' " - f"because it reshapes an argument: actual argument " - f"'{actual_arg.debug_string()}' has rank {actual_rank}" - f" but the corresponding formal argument, " - f"'{formal_arg.name}', has rank {formal_rank}")) - if actual_rank: - ranges = actual_arg.walk(Range) - for rge in ranges: - ancestor_ref = rge.ancestor(Reference) - if ancestor_ref is not actual_arg: - # Have a range in an indirect access. - # pylint: disable=cell-var-from-loop - raise TransformationError(LazyString( - lambda: f"Cannot inline routine '{routine.name}' " - f"because argument '{actual_arg.debug_string()}' " - f"has an array range in an indirect access (TODO " - f"#924).")) - if rge.step != _ONE: - # TODO #1646. We could resolve this problem by making - # a new array and copying the necessary values into it. - # pylint: disable=cell-var-from-loop - raise TransformationError(LazyString( - lambda: f"Cannot inline routine '{routine.name}' " - f"because one of its arguments is an array slice " - f"with a non-unit stride: " - f"'{actual_arg.debug_string()}' (TODO #1646)")) + else: + # A routine has been provided. + # Therefore, we just determine the matching argument list + # if it matches. + try: + rm = self._call_routine_matcher + rm.set_routine_node(self._routine_node) + rm.set_option( + check_strict_array_datatype=False) + self._ret_arg_match_list = ( + rm.get_argument_routine_match_list() + ) + except CallMatchingArgumentsNotFoundError as err: + raise TransformationError( + "Routine's argument(s) don't match:\n"+str(err) + ) from err + + self._validate_inline_of_call_and_routine( + call_node, self._routine_node, self._ret_arg_match_list) # For AutoAPI auto-documentation generation. diff --git a/src/psyclone/psyir/transformations/omp_task_trans.py b/src/psyclone/psyir/transformations/omp_task_trans.py index e3b044ad4e..6a3d2f2e3c 100644 --- a/src/psyclone/psyir/transformations/omp_task_trans.py +++ b/src/psyclone/psyir/transformations/omp_task_trans.py @@ -58,6 +58,23 @@ class OMPTaskTrans(ParallelLoopTrans): implementation. ''' + def __init__(self): + super().__init__() + + # If 'True', the callee must have matching arguments. + # The 'matching' criteria can be weakened by other options. + # If 'False', in case no match was found, the first callee is taken. + self._option_check_matching_arguments_of_callee: bool = True + + def set_option( + self, + check_matching_arguments_of_callee: bool = None, + ): + if check_matching_arguments_of_callee is not None: + self._option_check_matching_arguments_of_callee = ( + check_matching_arguments_of_callee + ) + def __str__(self): return "Adds an 'OMP TASK' directive to a statement" @@ -98,6 +115,11 @@ def validate(self, node, options=None): kintrans = KernelModuleInlineTrans() cond_trans = FoldConditionalReturnExpressionsTrans() intrans = InlineTrans() + intrans.set_option( + check_argument_matching=( + self._option_check_matching_arguments_of_callee + ) + ) for kern in kerns: kintrans.validate(kern) cond_trans.validate(kern.get_kernel_schedule()) @@ -157,6 +179,11 @@ def _inline_kernels(self, node): kintrans = KernelModuleInlineTrans() cond_trans = FoldConditionalReturnExpressionsTrans() intrans = InlineTrans() + intrans.set_option( + check_argument_matching=( + self._option_check_matching_arguments_of_callee + ) + ) for kern in kerns: kintrans.apply(kern) cond_trans.apply(kern.get_kernel_schedule()) diff --git a/src/psyclone/tests/psyir/nodes/call_test.py b/src/psyclone/tests/psyir/nodes/call_test.py index 667a8ee682..d07cb73aea 100644 --- a/src/psyclone/tests/psyir/nodes/call_test.py +++ b/src/psyclone/tests/psyir/nodes/call_test.py @@ -41,15 +41,27 @@ from psyclone.configuration import Config from psyclone.core import Signature, VariablesAccessInfo from psyclone.errors import GenerationError -from psyclone.parse import ModuleManager from psyclone.psyir.nodes import ( - ArrayReference, Assignment, BinaryOperation, Call, CodeBlock, Literal, - Node, Reference, Routine, Schedule) -from psyclone.psyir.nodes.call import CallMatchingArgumentsNotFound + ArrayReference, + BinaryOperation, + Call, + Literal, + Reference, + Routine, + Schedule, +) from psyclone.psyir.nodes.node import colored from psyclone.psyir.symbols import ( - ArrayType, INTEGER_TYPE, DataSymbol, NoType, RoutineSymbol, REAL_TYPE, - SymbolError, UnsupportedFortranType) + ArrayType, + INTEGER_TYPE, + DataSymbol, + NoType, + RoutineSymbol, + REAL_TYPE +) + +from psyclone.psyir.tools.call_routine_matcher import ( + CallMatchingArgumentsNotFoundError) class SpecialCall(Call): @@ -604,632 +616,7 @@ def test_copy(): assert call._argument_names != call2._argument_names -def test_call_get_callees_local(fortran_reader): - ''' - Check that get_callees() works as expected when the target of the Call - exists in the same Container as the call site. - ''' - code = ''' -module some_mod - implicit none - integer :: luggage -contains - subroutine top() - luggage = 0 - call bottom() - end subroutine top - - subroutine bottom() - luggage = luggage + 1 - end subroutine bottom -end module some_mod''' - psyir = fortran_reader.psyir_from_source(code) - call = psyir.walk(Call)[0] - result = call.get_callees() - assert result == [psyir.walk(Routine)[1]] - - -def test_call_get_callee_1_simple_match(fortran_reader): - ''' - Check that the right routine has been found for a single routine - implementation. - ''' - code = ''' -module some_mod - implicit none -contains - - subroutine main() - integer :: e, f, g - call foo(e, f, g) - end subroutine - - ! Matching routine - subroutine foo(a, b, c) - integer :: a, b, c - end subroutine - -end module some_mod''' - - psyir = fortran_reader.psyir_from_source(code) - - routine_main: Routine = psyir.walk(Routine)[0] - assert routine_main.name == "main" - - call_foo: Call = routine_main.walk(Call)[0] - - (result, _) = call_foo.get_callee() - - routine_match: Routine = psyir.walk(Routine)[1] - assert result is routine_match - - -def test_call_get_callee_2_optional_args(fortran_reader): - ''' - Check that optional arguments have been correlated correctly. - ''' - code = ''' -module some_mod - implicit none -contains - - subroutine main() - integer :: e, f, g - call foo(e, f) - end subroutine - - ! Matching routine - subroutine foo(a, b, c) - integer :: a, b - integer, optional :: c - end subroutine - -end module some_mod''' - - root_node: Node = fortran_reader.psyir_from_source(code) - - routine_main: Routine = root_node.walk(Routine)[0] - assert routine_main.name == "main" - - routine_match: Routine = root_node.walk(Routine)[1] - assert routine_match.name == "foo" - - call_foo: Call = routine_main.walk(Call)[0] - assert call_foo.routine.name == "foo" - - (result, arg_idx_list) = call_foo.get_callee() - result: Routine - - assert len(arg_idx_list) == 2 - assert arg_idx_list[0] == 0 - assert arg_idx_list[1] == 1 - - assert result is routine_match - - -def test_call_get_callee_3a_trigger_error(fortran_reader): - ''' - Test which is supposed to trigger an error when no matching routine - is found - ''' - code = ''' -module some_mod - implicit none -contains - - subroutine main() - integer :: e, f, g - call foo(e, f, g) - end subroutine - - ! Matching routine - subroutine foo(a, b) - integer :: a, b - end subroutine - -end module some_mod''' - - root_node: Node = fortran_reader.psyir_from_source(code) - - routine_main: Routine = root_node.walk(Routine)[0] - assert routine_main.name == "main" - - call_foo: Call = routine_main.walk(Call)[0] - assert call_foo.routine.name == "foo" - - with pytest.raises(CallMatchingArgumentsNotFound) as err: - call_foo.get_callee() - - assert "No matching routine found for" in str(err.value) - - -def test_call_get_callee_3c_trigger_error(fortran_reader): - ''' - Test which is supposed to trigger an error when no matching routine - is found, but we use the special option check_matching_arguments=False - to find one. - ''' - code = ''' -module some_mod - implicit none -contains - - subroutine main() - integer :: e, f, g - call foo(e, f, g) - end subroutine - - ! Matching routine - subroutine foo(a, b) - integer :: a, b - end subroutine - -end module some_mod''' - - root_node: Node = fortran_reader.psyir_from_source(code) - - routine_main: Routine = root_node.walk(Routine)[0] - assert routine_main.name == "main" - - call_foo: Call = routine_main.walk(Call)[0] - assert call_foo.routine.name == "foo" - - call_foo.get_callee(check_matching_arguments=False) - - -def test_call_get_callee_4_named_arguments(fortran_reader): - ''' - Check that named arguments have been correlated correctly - ''' - code = ''' -module some_mod - implicit none -contains - - subroutine main() - integer :: e, f, g - call foo(c=e, a=f, b=g) - end subroutine - - ! Matching routine - subroutine foo(a, b, c) - integer :: a, b, c - end subroutine - -end module some_mod''' - - root_node: Node = fortran_reader.psyir_from_source(code) - - routine_main: Routine = root_node.walk(Routine)[0] - assert routine_main.name == "main" - - routine_match: Routine = root_node.walk(Routine)[1] - assert routine_match.name == "foo" - - call_foo: Call = routine_main.walk(Call)[0] - assert call_foo.routine.name == "foo" - - (result, arg_idx_list) = call_foo.get_callee() - result: Routine - - assert len(arg_idx_list) == 3 - assert arg_idx_list[0] == 2 - assert arg_idx_list[1] == 0 - assert arg_idx_list[2] == 1 - - assert result is routine_match - - -def test_call_get_callee_5_optional_and_named_arguments(fortran_reader): - ''' - Check that optional and named arguments have been correlated correctly - when the call is to a generic interface. - ''' - code = ''' -module some_mod - implicit none -contains - - subroutine main() - integer :: e, f, g - call foo(b=e, a=f) - end subroutine - - ! Matching routine - subroutine foo(a, b, c) - integer :: a, b - integer, optional :: c - end subroutine - -end module some_mod''' - - root_node: Node = fortran_reader.psyir_from_source(code) - - routine_main: Routine = root_node.walk(Routine)[0] - assert routine_main.name == "main" - - routine_match: Routine = root_node.walk(Routine)[1] - assert routine_match.name == "foo" - - call_foo: Call = routine_main.walk(Call)[0] - assert call_foo.routine.name == "foo" - - (result, arg_idx_list) = call_foo.get_callee() - result: Routine - - assert len(arg_idx_list) == 2 - assert arg_idx_list[0] == 1 - assert arg_idx_list[1] == 0 - - assert result is routine_match - - -_code_test_get_callee_6 = ''' -module some_mod - implicit none - - interface foo - procedure foo_a, foo_b, foo_c, foo_optional - end interface -contains - - subroutine main() - integer :: e_int, f_int, g_int - real :: e_real, f_real, g_real - - ! Should match foo_a, test_call_get_callee_6_interfaces_0_0 - call foo(e_int, f_int) - - ! Should match foo_a, test_call_get_callee_6_interfaces_0_1 - call foo(e_int, f_int, g_int) - - ! Should match foo_b, test_call_get_callee_6_interfaces_1_0 - call foo(e_real, f_int) - - ! Should match foo_b, test_call_get_callee_6_interfaces_1_1 - call foo(e_real, f_int, g_int) - - ! Should match foo_b, test_call_get_callee_6_interfaces_1_2 - call foo(e_real, c=f_int, b=g_int) - - ! Should match foo_c, test_call_get_callee_6_interfaces_2_0 - call foo(e_int, f_real, g_int) - - ! Should match foo_c, test_call_get_callee_6_interfaces_2_1 - call foo(b=e_real, a=f_int) - - ! Should match foo_c, test_call_get_callee_6_interfaces_2_2 - call foo(b=e_real, a=f_int, g_int) - - ! Should not match foo_optional because of invalid type, - ! test_call_get_callee_6_interfaces_3_0_mismatch - call foo(f_int, e_real, g_int, g_int) - end subroutine - - subroutine foo_a(a, b, c) - integer :: a, b - integer, optional :: c - end subroutine - - subroutine foo_b(a, b, c) - real :: a - integer :: b - integer, optional :: c - end subroutine - - subroutine foo_c(a, b, c) - integer :: a - real :: b - integer, optional :: c - end subroutine - - subroutine foo_optional(a, b, c, d) - integer :: a - real :: b - integer :: c - real, optional :: d ! real vs. int - end subroutine - - -end module some_mod''' - - -def test_call_get_callee_6_interfaces_0_0(fortran_reader): - ''' - Check that a non-existing optional argument at the end of the list - has been correctly determined. - ''' - - root_node: Node = fortran_reader.psyir_from_source(_code_test_get_callee_6) - - routine_main: Routine = root_node.walk(Routine)[0] - assert routine_main.name == "main" - - routine_foo_a: Routine = root_node.walk(Routine)[1] - assert routine_foo_a.name == "foo_a" - - call_foo_a: Call = routine_main.walk(Call)[0] - assert call_foo_a.routine.name == "foo" - - (result, arg_idx_list) = call_foo_a.get_callee() - result: Routine - - assert len(arg_idx_list) == 2 - assert arg_idx_list[0] == 0 - assert arg_idx_list[1] == 1 - - assert result is routine_foo_a - - -def test_call_get_callee_6_interfaces_0_1(fortran_reader): - ''' - Check that an existing optional argument at the end of the list - has been correctly determined. - ''' - - root_node: Node = fortran_reader.psyir_from_source(_code_test_get_callee_6) - - routine_main: Routine = root_node.walk(Routine)[0] - assert routine_main.name == "main" - - routine_foo_a: Routine = root_node.walk(Routine)[1] - assert routine_foo_a.name == "foo_a" - - call_foo_a: Call = routine_main.walk(Call)[1] - assert call_foo_a.routine.name == "foo" - - (result, arg_idx_list) = call_foo_a.get_callee() - result: Routine - - assert len(arg_idx_list) == 3 - assert arg_idx_list[0] == 0 - assert arg_idx_list[1] == 1 - assert arg_idx_list[2] == 2 - - assert result is routine_foo_a - - -def test_call_get_callee_6_interfaces_1_0(fortran_reader): - ''' - Check that - - different argument types and - - non-existing optional argument at the end of the list - have been correctly determined. - ''' - - root_node: Node = fortran_reader.psyir_from_source(_code_test_get_callee_6) - - routine_main: Routine = root_node.walk(Routine)[0] - assert routine_main.name == "main" - - routine_foo_b: Routine = root_node.walk(Routine)[2] - assert routine_foo_b.name == "foo_b" - - call_foo_b: Call = routine_main.walk(Call)[2] - assert call_foo_b.routine.name == "foo" - - (result, arg_idx_list) = call_foo_b.get_callee() - result: Routine - - assert len(arg_idx_list) == 2 - assert arg_idx_list[0] == 0 - assert arg_idx_list[1] == 1 - - assert result is routine_foo_b - - -def test_call_get_callee_6_interfaces_1_1(fortran_reader): - ''' - Check that - - different argument types and - - existing optional argument at the end of the list - have been correctly determined. - ''' - - root_node: Node = fortran_reader.psyir_from_source(_code_test_get_callee_6) - - routine_main: Routine = root_node.walk(Routine)[0] - assert routine_main.name == "main" - - routine_foo_b: Routine = root_node.walk(Routine)[2] - assert routine_foo_b.name == "foo_b" - - call_foo_b: Call = routine_main.walk(Call)[3] - assert call_foo_b.routine.name == "foo" - - (result, arg_idx_list) = call_foo_b.get_callee() - result: Routine - - assert len(arg_idx_list) == 3 - assert arg_idx_list[0] == 0 - assert arg_idx_list[1] == 1 - assert arg_idx_list[2] == 2 - - assert result is routine_foo_b - - -def test_call_get_callee_6_interfaces_1_2(fortran_reader): - ''' - Check that - - different argument types and - - naming arguments resulting in a different order - have been correctly determined. - ''' - - root_node: Node = fortran_reader.psyir_from_source(_code_test_get_callee_6) - - routine_main: Routine = root_node.walk(Routine)[0] - assert routine_main.name == "main" - - routine_foo_b: Routine = root_node.walk(Routine)[2] - assert routine_foo_b.name == "foo_b" - - call_foo_b: Call = routine_main.walk(Call)[4] - assert call_foo_b.routine.name == "foo" - - (result, arg_idx_list) = call_foo_b.get_callee() - result: Routine - - assert len(arg_idx_list) == 3 - assert arg_idx_list[0] == 0 - assert arg_idx_list[1] == 2 - assert arg_idx_list[2] == 1 - - assert result is routine_foo_b - - -def test_call_get_callee_6_interfaces_2_0(fortran_reader): - ''' - Check that - - different argument types (different order than in tests before) - have been correctly determined. - ''' - - root_node: Node = fortran_reader.psyir_from_source(_code_test_get_callee_6) - - routine_main: Routine = root_node.walk(Routine)[0] - assert routine_main.name == "main" - - routine_foo_c: Routine = root_node.walk(Routine)[3] - assert routine_foo_c.name == "foo_c" - - call_foo_c: Call = routine_main.walk(Call)[5] - assert call_foo_c.routine.name == "foo" - - (result, arg_idx_list) = call_foo_c.get_callee() - result: Routine - - assert len(arg_idx_list) == 3 - assert arg_idx_list[0] == 0 - assert arg_idx_list[1] == 1 - assert arg_idx_list[2] == 2 - - assert result is routine_foo_c - - -def test_call_get_callee_6_interfaces_2_1(fortran_reader): - ''' - Check that - - different argument types (different order than in tests before) and - - naming arguments resulting in a different order and - - optional argument - have been correctly determined. - ''' - - root_node: Node = fortran_reader.psyir_from_source(_code_test_get_callee_6) - - routine_main: Routine = root_node.walk(Routine)[0] - assert routine_main.name == "main" - - routine_foo_c: Routine = root_node.walk(Routine)[3] - assert routine_foo_c.name == "foo_c" - - call_foo_c: Call = routine_main.walk(Call)[6] - assert call_foo_c.routine.name == "foo" - - (result, arg_idx_list) = call_foo_c.get_callee() - result: Routine - - assert len(arg_idx_list) == 2 - assert arg_idx_list[0] == 1 - assert arg_idx_list[1] == 0 - - assert result is routine_foo_c - - -def test_call_get_callee_6_interfaces_2_2(fortran_reader): - ''' - Check that - - different argument types (different order than in tests before) and - - naming arguments resulting in a different order and - - last call argument without naming - have been correctly determined. - ''' - - root_node: Node = fortran_reader.psyir_from_source(_code_test_get_callee_6) - - routine_main: Routine = root_node.walk(Routine)[0] - assert routine_main.name == "main" - - routine_foo_c: Routine = root_node.walk(Routine)[3] - assert routine_foo_c.name == "foo_c" - - call_foo_c: Call = routine_main.walk(Call)[7] - assert call_foo_c.routine.name == "foo" - - (result, arg_idx_list) = call_foo_c.get_callee() - result: Routine - - assert len(arg_idx_list) == 3 - assert arg_idx_list[0] == 1 - assert arg_idx_list[1] == 0 - assert arg_idx_list[2] == 2 - - assert result is routine_foo_c - - -def test_call_get_callee_6_interfaces_3_0_mismatch(fortran_reader): - ''' - Check that matching a partial data type can also go wrong. - ''' - - root_node: Node = fortran_reader.psyir_from_source(_code_test_get_callee_6) - - routine_main: Routine = root_node.walk(Routine)[0] - assert routine_main.name == "main" - - routine_foo_optional: Routine = root_node.walk(Routine)[4] - assert routine_foo_optional.name == "foo_optional" - - call_foo_optional: Call = routine_main.walk(Call)[8] - assert call_foo_optional.routine.name == "foo" - - with pytest.raises(CallMatchingArgumentsNotFound) as einfo: - call_foo_optional.get_callee() - - assert "Argument partial type mismatch of call argument" in ( - str(einfo.value)) - - -def test_call_get_callee_7_matching_arguments_not_found(fortran_reader): - ''' - Trigger error that matching arguments were not found - ''' - code = ''' -module some_mod - implicit none -contains - - subroutine main() - integer :: e, f, g - ! Use named argument 'd', which doesn't exist - ! to trigger an error when searching for the matching routine. - call foo(e, f, d=g) - end subroutine - - ! Matching routine - subroutine foo(a, b, c) - integer :: a, b, c - end subroutine - -end module some_mod''' - - psyir = fortran_reader.psyir_from_source(code) - - routine_main: Routine = psyir.walk(Routine)[0] - assert routine_main.name == "main" - - call_foo: Call = routine_main.walk(Call)[0] - - with pytest.raises(CallMatchingArgumentsNotFound) as err: - call_foo.get_callee() - - assert "No matching routine found for 'call foo(e, f, d=g)" in str( - err.value - ) - - -def test_call_get_callee_8_arguments_not_handled(fortran_reader): +def test_call_get_callee_arguments_not_handled(fortran_reader): ''' Trigger error that matching arguments were not found. In this test, this is caused by omitting the required third non-optional @@ -1260,415 +647,15 @@ def test_call_get_callee_8_arguments_not_handled(fortran_reader): call_foo: Call = routine_main.walk(Call)[0] - with pytest.raises(CallMatchingArgumentsNotFound) as err: + with pytest.raises(CallMatchingArgumentsNotFoundError) as err: call_foo.get_callee() - assert "No matching routine found for 'call foo(e, f)" in str(err.value) + assert ("CallMatchingArgumentsNotFound: Found routines, but" + " no routine with matching arguments found for 'call" + " foo(e, f)':" in str(err.value)) - -@pytest.mark.usefixtures("clear_module_manager_instance") -def test_call_get_callees_unresolved(fortran_reader, tmpdir, monkeypatch): - ''' - Test that get_callees() raises the expected error if the called routine - is unresolved. - ''' - code = ''' -subroutine top() - call bottom() -end subroutine top''' - psyir = fortran_reader.psyir_from_source(code) - call = psyir.walk(Call)[0] - with pytest.raises(NotImplementedError) as err: - _ = call.get_callees() - assert ("Failed to find the source code of the unresolved routine 'bottom'" - " - looked at any routines in the same source file and there are " - "no wildcard imports." in str(err.value)) - # Repeat but in the presence of a wildcard import. - code = ''' -subroutine top() - use some_mod_somewhere - call bottom() -end subroutine top''' - psyir = fortran_reader.psyir_from_source(code) - call = psyir.walk(Call)[0] - with pytest.raises(NotImplementedError) as err: - _ = call.get_callees() - assert ("Failed to find the source code of the unresolved routine 'bottom'" - " - looked at any routines in the same source file and attempted " - "to resolve the wildcard imports from ['some_mod_somewhere']. " - "However, failed to find the source for ['some_mod_somewhere']. " - "The module search path is set to []" in str(err.value)) - # Repeat but when some_mod_somewhere *is* resolved but doesn't help us - # find the routine we're looking for. - mod_manager = ModuleManager.get() - monkeypatch.setattr(mod_manager, "_instance", None) - path = str(tmpdir) - monkeypatch.setattr(Config.get(), '_include_paths', [path]) - with open(os.path.join(path, "some_mod_somewhere.f90"), "w", - encoding="utf-8") as ofile: - ofile.write('''\ -module some_mod_somewhere -end module some_mod_somewhere -''') - with pytest.raises(NotImplementedError) as err: - _ = call.get_callees() - assert ("Failed to find the source code of the unresolved routine 'bottom'" - " - looked at any routines in the same source file and wildcard " - "imports from ['some_mod_somewhere']." in str(err.value)) - mod_manager = ModuleManager.get() - monkeypatch.setattr(mod_manager, "_instance", None) - code = ''' -subroutine top() - use another_mod, only: this_one - call this_one() -end subroutine top''' - psyir = fortran_reader.psyir_from_source(code) - call = psyir.walk(Call)[0] - with pytest.raises(NotImplementedError) as err: - _ = call.get_callees() - assert ("RoutineSymbol 'this_one' is imported from Container 'another_mod'" - " but the source defining that container could not be found. The " - "module search path is set to [" in str(err.value)) - - -def test_call_get_callees_interface(fortran_reader): - ''' - Check that get_callees() works correctly when the target of a call is - actually a generic interface. - ''' - code = ''' -module my_mod - - interface bottom - module procedure :: rbottom, ibottom - end interface bottom -contains - subroutine top() - integer :: luggage - luggage = 0 - call bottom(luggage) - end subroutine top - - subroutine ibottom(luggage) - integer :: luggage - luggage = luggage + 1 - end subroutine ibottom - - subroutine rbottom(luggage) - real :: luggage - luggage = luggage + 1.0 - end subroutine rbottom -end module my_mod -''' - psyir = fortran_reader.psyir_from_source(code) - call = psyir.walk(Call)[0] - callees = call.get_callees() - assert len(callees) == 2 - assert isinstance(callees[0], Routine) - assert callees[0].name == "rbottom" - assert isinstance(callees[1], Routine) - assert callees[1].name == "ibottom" - - -def test_call_get_callees_unsupported_type(fortran_reader): - ''' - Check that get_callees() raises the expected error when the called routine - is of UnsupportedFortranType. This is hard to achieve so we have to - manually construct some aspects of the test case. - - ''' - code = ''' -module my_mod - integer, target :: value -contains - subroutine top() - integer :: luggage - luggage = bottom() - end subroutine top - function bottom() result(fval) - integer, pointer :: fval - fval => value - end function bottom -end module my_mod -''' - psyir = fortran_reader.psyir_from_source(code) - container = psyir.children[0] - routine = container.find_routine_psyir("bottom") - rsym = container.symbol_table.lookup(routine.name) - # Ensure the type of this RoutineSymbol is UnsupportedFortranType. - rsym.datatype = UnsupportedFortranType("integer, pointer :: fval") - assign = container.walk(Assignment)[0] - # Currently `bottom()` gets matched by fparser2 as a structure constructor - # and the fparser2 frontend leaves this as a CodeBlock (TODO #2429) so - # replace it with a Call. Once #2429 is fixed the next two lines can be - # removed. - assert isinstance(assign.rhs, CodeBlock) - assign.rhs.replace_with(Call.create(rsym)) - call = psyir.walk(Call)[0] - with pytest.raises(NotImplementedError) as err: - _ = call.get_callees() - assert ("RoutineSymbol 'bottom' exists in Container 'my_mod' but is of " - "UnsupportedFortranType" in str(err.value)) - - -def test_call_get_callees_file_container(fortran_reader): - ''' - Check that get_callees works if the called routine happens to be in file - scope, even when there's no Container. - ''' - code = ''' - subroutine top() - integer :: luggage - luggage = 0 - call bottom(luggage) - end subroutine top - - subroutine bottom(luggage) - integer :: luggage - luggage = luggage + 1 - end subroutine bottom -''' - psyir = fortran_reader.psyir_from_source(code) - call = psyir.walk(Call)[0] - result = call.get_callees() - assert len(result) == 1 - assert isinstance(result[0], Routine) - assert result[0].name == "bottom" - - -def test_call_get_callees_no_container(fortran_reader): - ''' - Check that get_callees() raises the expected error when the Call is not - within a Container and the target routine cannot be found. - ''' - # To avoid having the routine symbol immediately dismissed as - # unresolved, the code that we initially process *does* have a Container. - code = ''' -module my_mod - -contains - subroutine top() - integer :: luggage - luggage = 0 - call bottom(luggage) - end subroutine top - - subroutine bottom(luggage) - integer :: luggage - luggage = luggage + 1 - end subroutine bottom -end module my_mod -''' - psyir = fortran_reader.psyir_from_source(code) - top_routine = psyir.walk(Routine)[0] - # Deliberately make the Routine node an orphan so there's no Container. - top_routine.detach() - call = top_routine.walk(Call)[0] - with pytest.raises(SymbolError) as err: - _ = call.get_callees() - assert ("Failed to find a Routine named 'bottom' in code:\n'subroutine " - "top()" in str(err.value)) - - -def test_call_get_callees_wildcard_import_local_container(fortran_reader): - ''' - Check that get_callees() works successfully for a routine accessed via - a wildcard import from another module in the same file. - ''' - code = ''' -module some_mod -contains - subroutine just_do_it() - write(*,*) "hello" - end subroutine just_do_it -end module some_mod -module other_mod - use some_mod -contains - subroutine run_it() - call just_do_it() - end subroutine run_it -end module other_mod -''' - psyir = fortran_reader.psyir_from_source(code) - call = psyir.walk(Call)[0] - routines = call.get_callees() - assert len(routines) == 1 - assert isinstance(routines[0], Routine) - assert routines[0].name == "just_do_it" - - -def test_call_get_callees_import_local_container(fortran_reader): - ''' - Check that get_callees() works successfully for a routine accessed via - a specific import from another module in the same file. - ''' - code = ''' -module some_mod -contains - subroutine just_do_it() - write(*,*) "hello" - end subroutine just_do_it -end module some_mod -module other_mod - use some_mod, only: just_do_it -contains - subroutine run_it() - call just_do_it() - end subroutine run_it -end module other_mod -''' - psyir = fortran_reader.psyir_from_source(code) - call = psyir.walk(Call)[0] - routines = call.get_callees() - assert len(routines) == 1 - assert isinstance(routines[0], Routine) - assert routines[0].name == "just_do_it" - - -@pytest.mark.usefixtures("clear_module_manager_instance") -def test_call_get_callees_wildcard_import_container(fortran_reader, - tmpdir, monkeypatch): - ''' - Check that get_callees() works successfully for a routine accessed via - a wildcard import from a module in another file. - ''' - code = ''' -module other_mod - use some_mod -contains - subroutine run_it() - call just_do_it() - end subroutine run_it -end module other_mod -''' - psyir = fortran_reader.psyir_from_source(code) - call = psyir.walk(Call)[0] - # This should fail as it can't find the module. - with pytest.raises(NotImplementedError) as err: - _ = call.get_callees() - assert ("Failed to find the source code of the unresolved routine " - "'just_do_it' - looked at any routines in the same source file" - in str(err.value)) - # Create the module containing the subroutine definition, - # write it to file and set the search path so that PSyclone can find it. - path = str(tmpdir) - monkeypatch.setattr(Config.get(), '_include_paths', [path]) - - with open(os.path.join(path, "some_mod.f90"), - "w", encoding="utf-8") as mfile: - mfile.write('''\ -module some_mod -contains - subroutine just_do_it() - write(*,*) "hello" - end subroutine just_do_it -end module some_mod''') - routines = call.get_callees() - assert len(routines) == 1 - assert isinstance(routines[0], Routine) - assert routines[0].name == "just_do_it" - - -def test_fn_call_get_callees(fortran_reader): - ''' - Test that get_callees() works for a function call. - ''' - code = ''' -module some_mod - implicit none - integer :: luggage -contains - subroutine top() - luggage = 0 - luggage = luggage + my_func(1) - end subroutine top - - function my_func(val) - integer, intent(in) :: val - integer :: my_func - my_func = 1 + val - end function my_func -end module some_mod''' - psyir = fortran_reader.psyir_from_source(code) - call = psyir.walk(Call)[0] - result = call.get_callees() - assert result == [psyir.walk(Routine)[1]] - - -def test_get_callees_code_block(fortran_reader): - '''Test that get_callees() raises the expected error when the called - routine is in a CodeBlock.''' - code = ''' -module some_mod - implicit none - integer :: luggage -contains - subroutine top() - luggage = 0 - luggage = luggage + real(my_func(1)) - end subroutine top - - complex function my_func(val) - integer, intent(in) :: val - my_func = CMPLX(1 + val, 1.0) - end function my_func -end module some_mod''' - psyir = fortran_reader.psyir_from_source(code) - call = psyir.walk(Call)[1] - with pytest.raises(SymbolError) as err: - _ = call.get_callees() - assert ("Failed to find a Routine named 'my_func' in Container " - "'some_mod'" in str(err.value)) - - -@pytest.mark.usefixtures("clear_module_manager_instance") -def test_get_callees_follow_imports(fortran_reader, tmpdir, monkeypatch): - ''' - Test that get_callees() follows imports to find the definition of the - called routine. - ''' - code = ''' -module some_mod - use other_mod, only: pack_it - implicit none -contains - subroutine top() - integer :: luggage = 0 - call pack_it(luggage) - end subroutine top -end module some_mod''' - # Create the module containing an import of the subroutine definition, - # write it to file and set the search path so that PSyclone can find it. - path = str(tmpdir) - monkeypatch.setattr(Config.get(), '_include_paths', [path]) - - with open(os.path.join(path, "other_mod.f90"), - "w", encoding="utf-8") as mfile: - mfile.write('''\ - module other_mod - use another_mod, only: pack_it - contains - end module other_mod - ''') - # Finally, create the module containing the routine definition. - with open(os.path.join(path, "another_mod.f90"), - "w", encoding="utf-8") as mfile: - mfile.write('''\ - module another_mod - contains - subroutine pack_it(arg) - integer, intent(inout) :: arg - arg = arg + 2 - end subroutine pack_it - end module another_mod - ''') - psyir = fortran_reader.psyir_from_source(code) - call = psyir.walk(Call)[0] - result = call.get_callees() - assert len(result) == 1 - assert isinstance(result[0], Routine) - assert result[0].name == "pack_it" + assert ("CallMatchingArgumentsNotFound: Argument 'c' in" + " subroutine 'foo' not handled." in str(err.value)) @pytest.mark.usefixtures("clear_module_manager_instance") diff --git a/src/psyclone/tests/psyir/tools/call_routine_matcher_test.py b/src/psyclone/tests/psyir/tools/call_routine_matcher_test.py new file mode 100644 index 0000000000..7173622194 --- /dev/null +++ b/src/psyclone/tests/psyir/tools/call_routine_matcher_test.py @@ -0,0 +1,1425 @@ +# ----------------------------------------------------------------------------- +# BSD 3-Clause License +# +# Copyright (c) 2020-2025, Science and Technology Facilities Council. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS +# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE +# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN +# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +# ----------------------------------------------------------------------------- +# This file is based on gathering various components related to +# calls and routines from across psyclone. Hence, there's no clear author. +# Authors of gathered files: R. W. Ford, A. R. Porter and +# S. Siso, STFC Daresbury Lab +# Author: M. Schreiber, Univ. Grenoble Alpes / LJK / Inria +# ----------------------------------------------------------------------------- + + +import os +import pytest +from psyclone.configuration import Config +from psyclone.parse import ModuleManager +from psyclone.psyir.tools.call_routine_matcher import ( + CallRoutineMatcher, + CallMatchingArgumentsNotFoundError, +) +from psyclone.psyir.symbols import UnsupportedFortranType, SymbolError +from psyclone.psyir.nodes import Call, Node, Routine, Assignment, CodeBlock +from psyclone.psyir.transformations import InlineTrans +from psyclone.tests.utilities import Compile + + +def test_apply_optional_and_named_arg(fortran_reader): + """Test that the validate method inlines a routine + that has an optional argument.""" + code = ( + "module test_mod\n" + "contains\n" + "subroutine main\n" + " real :: var = 0.0\n" + " call sub(var, named=1.0)\n" + " ! Result:\n" + " ! var = var + 1.0 + 1.0\n" + " call sub(var, 2.0, named=1.0)\n" + " ! Result:\n" + " ! var = var + 2.0\n" + " ! var = var + 1.0 + 1.0\n" + "end subroutine main\n" + "subroutine sub(x, opt, named)\n" + " real, intent(inout) :: x\n" + " real, optional :: opt\n" + " real :: named\n" + " if( present(opt) )then\n" + " x = x + opt\n" + " end if\n" + " x = x + 1.0 + named\n" + "end subroutine sub\n" + "end module test_mod\n" + ) + psyir: Node = fortran_reader.psyir_from_source(code) + + inline_trans = InlineTrans() + + routine_main: Routine = psyir.walk(Routine)[0] + assert routine_main.name == "main" + for call in psyir.walk(Call, stop_type=Call): + call: Call + if call.routine.name != "sub": + continue + + inline_trans.apply(call) + + assert ( + """var = var + 1.0 + 1.0 + var = var + 2.0 + var = var + 1.0 + 1.0""" + in routine_main.debug_string() + ) + + +def test_unresolved_types(fortran_reader): + """Test that the validate method inlines a routine that has a named + argument.""" + + code = ( + "module test_mod\n" + "contains\n" + "subroutine main\n" + " real :: var = 0.0\n" + " call sub(var, opt=1.0)\n" + "end subroutine main\n" + "subroutine sub(x, opt)\n" + " real, intent(inout) :: x\n" + " real :: opt\n" + " x = x + 1.0\n" + "end subroutine sub\n" + "end module test_mod\n" + ) + + psyir = fortran_reader.psyir_from_source(code) + call = psyir.walk(Call)[0] + + crm = CallRoutineMatcher(call) + + crm.set_option(ignore_unresolved_types=True) + crm.get_callee_candidates() + + +def test_call_get_callee_1_simple_match(fortran_reader): + """ + Check that the right routine has been found for a single routine + implementation. + """ + code = """ +module some_mod + implicit none +contains + + subroutine main() + integer :: e, f, g + call foo(e, f, g) + end subroutine + + ! Matching routine + subroutine foo(a, b, c) + integer :: a, b, c + end subroutine + +end module some_mod""" + + psyir = fortran_reader.psyir_from_source(code) + + routine_main: Routine = psyir.walk(Routine)[0] + assert routine_main.name == "main" + + call_foo: Call = routine_main.walk(Call)[0] + + (result, _) = call_foo.get_callee() + + routine_match: Routine = psyir.walk(Routine)[1] + assert result is routine_match + + +def test_call_get_callee_2_optional_args(fortran_reader): + """ + Check that optional arguments have been correlated correctly. + """ + code = """ +module some_mod + implicit none +contains + + subroutine main() + integer :: e, f, g + call foo(e, f) + end subroutine + + ! Matching routine + subroutine foo(a, b, c) + integer :: a, b + integer, optional :: c + end subroutine + +end module some_mod""" + + root_node: Node = fortran_reader.psyir_from_source(code) + + routine_main: Routine = root_node.walk(Routine)[0] + assert routine_main.name == "main" + + routine_match: Routine = root_node.walk(Routine)[1] + assert routine_match.name == "foo" + + call_foo: Call = routine_main.walk(Call)[0] + assert call_foo.routine.name == "foo" + + (result, arg_idx_list) = call_foo.get_callee() + result: Routine + + assert len(arg_idx_list) == 2 + assert arg_idx_list[0] == 0 + assert arg_idx_list[1] == 1 + + assert result is routine_match + + +def test_call_get_callee_3a_trigger_error(fortran_reader): + """ + Test which is supposed to trigger an error when no matching routine + is found + """ + code = """ +module some_mod + implicit none +contains + + subroutine main() + integer :: e, f, g + call foo(e, f, g) + end subroutine + + ! Matching routine + subroutine foo(a, b) + integer :: a, b + end subroutine + +end module some_mod""" + + root_node: Node = fortran_reader.psyir_from_source(code) + + routine_main: Routine = root_node.walk(Routine)[0] + assert routine_main.name == "main" + + call_foo: Call = routine_main.walk(Call)[0] + assert call_foo.routine.name == "foo" + + with pytest.raises(CallMatchingArgumentsNotFoundError) as err: + call_foo.get_callee() + + assert ( + "Found routines, but no routine with matching arguments found" + in str(err.value) + ) + + +def test_call_get_callee_3c_trigger_error(fortran_reader): + """ + Test which is supposed to trigger an error when no matching routine + is found, but we use the special option check_matching_arguments=False + to find one. + """ + code = """ +module some_mod + implicit none +contains + + subroutine main() + integer :: e, f, g + call foo(e, f, g) + end subroutine + + ! Matching routine + subroutine foo(a, b) + integer :: a, b + end subroutine + +end module some_mod""" + + root_node: Node = fortran_reader.psyir_from_source(code) + + routine_main: Routine = root_node.walk(Routine)[0] + assert routine_main.name == "main" + + call_foo: Call = routine_main.walk(Call)[0] + assert call_foo.routine.name == "foo" + + call_foo.get_callee(check_matching_arguments=False) + + +def test_call_get_callee_4_named_arguments(fortran_reader): + """ + Check that named arguments have been correlated correctly + """ + code = """ +module some_mod + implicit none +contains + + subroutine main() + integer :: e, f, g + call foo(c=e, a=f, b=g) + end subroutine + + ! Matching routine + subroutine foo(a, b, c) + integer :: a, b, c + end subroutine + +end module some_mod""" + + root_node: Node = fortran_reader.psyir_from_source(code) + + routine_main: Routine = root_node.walk(Routine)[0] + assert routine_main.name == "main" + + routine_match: Routine = root_node.walk(Routine)[1] + assert routine_match.name == "foo" + + call_foo: Call = routine_main.walk(Call)[0] + assert call_foo.routine.name == "foo" + + (result, arg_idx_list) = call_foo.get_callee() + result: Routine + + assert len(arg_idx_list) == 3 + assert arg_idx_list[0] == 2 + assert arg_idx_list[1] == 0 + assert arg_idx_list[2] == 1 + + assert result is routine_match + + +def test_call_get_callee_5_optional_and_named_arguments(fortran_reader): + """ + Check that optional and named arguments have been correlated correctly + when the call is to a generic interface. + """ + code = """ +module some_mod + implicit none +contains + + subroutine main() + integer :: e, f, g + call foo(b=e, a=f) + end subroutine + + ! Matching routine + subroutine foo(a, b, c) + integer :: a, b + integer, optional :: c + end subroutine + +end module some_mod""" + + root_node: Node = fortran_reader.psyir_from_source(code) + + routine_main: Routine = root_node.walk(Routine)[0] + assert routine_main.name == "main" + + routine_match: Routine = root_node.walk(Routine)[1] + assert routine_match.name == "foo" + + call_foo: Call = routine_main.walk(Call)[0] + assert call_foo.routine.name == "foo" + + (result, arg_idx_list) = call_foo.get_callee() + result: Routine + + assert len(arg_idx_list) == 2 + assert arg_idx_list[0] == 1 + assert arg_idx_list[1] == 0 + + assert result is routine_match + + +_code_test_get_callee_6 = """ +module some_mod + implicit none + + interface foo + procedure foo_a, foo_b, foo_c, foo_optional + end interface +contains + + subroutine main() + integer :: e_int, f_int, g_int + real :: e_real, f_real, g_real + + ! Should match foo_a, test_call_get_callee_6_interfaces_0_0 + call foo(e_int, f_int) + + ! Should match foo_a, test_call_get_callee_6_interfaces_0_1 + call foo(e_int, f_int, g_int) + + ! Should match foo_b, test_call_get_callee_6_interfaces_1_0 + call foo(e_real, f_int) + + ! Should match foo_b, test_call_get_callee_6_interfaces_1_1 + call foo(e_real, f_int, g_int) + + ! Should match foo_b, test_call_get_callee_6_interfaces_1_2 + call foo(e_real, c=f_int, b=g_int) + + ! Should match foo_c, test_call_get_callee_6_interfaces_2_0 + call foo(e_int, f_real, g_int) + + ! Should match foo_c, test_call_get_callee_6_interfaces_2_1 + call foo(b=e_real, a=f_int) + + ! Should match foo_c, test_call_get_callee_6_interfaces_2_2 + call foo(b=e_real, a=f_int, g_int) + + ! Should not match foo_optional because of invalid type, + ! test_call_get_callee_6_interfaces_3_0_mismatch + call foo(f_int, e_real, g_int, g_int) + end subroutine + + subroutine foo_a(a, b, c) + integer :: a, b + integer, optional :: c + end subroutine + + subroutine foo_b(a, b, c) + real :: a + integer :: b + integer, optional :: c + end subroutine + + subroutine foo_c(a, b, c) + integer :: a + real :: b + integer, optional :: c + end subroutine + + subroutine foo_optional(a, b, c, d) + integer :: a + real :: b + integer :: c + real, optional :: d ! real vs. int + end subroutine + + +end module some_mod""" + + +def test_call_get_callee_6_interfaces_0_0(fortran_reader): + """ + Check that a non-existing optional argument at the end of the list + has been correctly determined. + """ + + root_node: Node = fortran_reader.psyir_from_source(_code_test_get_callee_6) + + routine_main: Routine = root_node.walk(Routine)[0] + assert routine_main.name == "main" + + routine_foo_a: Routine = root_node.walk(Routine)[1] + assert routine_foo_a.name == "foo_a" + + call_foo_a: Call = routine_main.walk(Call)[0] + assert call_foo_a.routine.name == "foo" + + (result, arg_idx_list) = call_foo_a.get_callee() + result: Routine + + assert len(arg_idx_list) == 2 + assert arg_idx_list[0] == 0 + assert arg_idx_list[1] == 1 + + assert result is routine_foo_a + + +def test_call_get_callee_6_interfaces_0_1(fortran_reader): + """ + Check that an existing optional argument at the end of the list + has been correctly determined. + """ + + root_node: Node = fortran_reader.psyir_from_source(_code_test_get_callee_6) + + routine_main: Routine = root_node.walk(Routine)[0] + assert routine_main.name == "main" + + routine_foo_a: Routine = root_node.walk(Routine)[1] + assert routine_foo_a.name == "foo_a" + + call_foo_a: Call = routine_main.walk(Call)[1] + assert call_foo_a.routine.name == "foo" + + (result, arg_idx_list) = call_foo_a.get_callee() + result: Routine + + assert len(arg_idx_list) == 3 + assert arg_idx_list[0] == 0 + assert arg_idx_list[1] == 1 + assert arg_idx_list[2] == 2 + + assert result is routine_foo_a + + +def test_call_get_callee_6_interfaces_1_0(fortran_reader): + """ + Check that + - different argument types and + - non-existing optional argument at the end of the list + have been correctly determined. + """ + + root_node: Node = fortran_reader.psyir_from_source(_code_test_get_callee_6) + + routine_main: Routine = root_node.walk(Routine)[0] + assert routine_main.name == "main" + + routine_foo_b: Routine = root_node.walk(Routine)[2] + assert routine_foo_b.name == "foo_b" + + call_foo_b: Call = routine_main.walk(Call)[2] + assert call_foo_b.routine.name == "foo" + + (result, arg_idx_list) = call_foo_b.get_callee() + result: Routine + + assert len(arg_idx_list) == 2 + assert arg_idx_list[0] == 0 + assert arg_idx_list[1] == 1 + + assert result is routine_foo_b + + +def test_call_get_callee_6_interfaces_1_1(fortran_reader): + """ + Check that + - different argument types and + - existing optional argument at the end of the list + have been correctly determined. + """ + + root_node: Node = fortran_reader.psyir_from_source(_code_test_get_callee_6) + + routine_main: Routine = root_node.walk(Routine)[0] + assert routine_main.name == "main" + + routine_foo_b: Routine = root_node.walk(Routine)[2] + assert routine_foo_b.name == "foo_b" + + call_foo_b: Call = routine_main.walk(Call)[3] + assert call_foo_b.routine.name == "foo" + + (result, arg_idx_list) = call_foo_b.get_callee() + result: Routine + + assert len(arg_idx_list) == 3 + assert arg_idx_list[0] == 0 + assert arg_idx_list[1] == 1 + assert arg_idx_list[2] == 2 + + assert result is routine_foo_b + + +def test_call_get_callee_6_interfaces_1_2(fortran_reader): + """ + Check that + - different argument types and + - naming arguments resulting in a different order + have been correctly determined. + """ + + root_node: Node = fortran_reader.psyir_from_source(_code_test_get_callee_6) + + routine_main: Routine = root_node.walk(Routine)[0] + assert routine_main.name == "main" + + routine_foo_b: Routine = root_node.walk(Routine)[2] + assert routine_foo_b.name == "foo_b" + + call_foo_b: Call = routine_main.walk(Call)[4] + assert call_foo_b.routine.name == "foo" + + (result, arg_idx_list) = call_foo_b.get_callee() + result: Routine + + assert len(arg_idx_list) == 3 + assert arg_idx_list[0] == 0 + assert arg_idx_list[1] == 2 + assert arg_idx_list[2] == 1 + + assert result is routine_foo_b + + +def test_call_get_callee_6_interfaces_2_0(fortran_reader): + """ + Check that + - different argument types (different order than in tests before) + have been correctly determined. + """ + + root_node: Node = fortran_reader.psyir_from_source(_code_test_get_callee_6) + + routine_main: Routine = root_node.walk(Routine)[0] + assert routine_main.name == "main" + + routine_foo_c: Routine = root_node.walk(Routine)[3] + assert routine_foo_c.name == "foo_c" + + call_foo_c: Call = routine_main.walk(Call)[5] + assert call_foo_c.routine.name == "foo" + + (result, arg_idx_list) = call_foo_c.get_callee() + result: Routine + + assert len(arg_idx_list) == 3 + assert arg_idx_list[0] == 0 + assert arg_idx_list[1] == 1 + assert arg_idx_list[2] == 2 + + assert result is routine_foo_c + + +def test_call_get_callee_6_interfaces_2_1(fortran_reader): + """ + Check that + - different argument types (different order than in tests before) and + - naming arguments resulting in a different order and + - optional argument + have been correctly determined. + """ + + root_node: Node = fortran_reader.psyir_from_source(_code_test_get_callee_6) + + routine_main: Routine = root_node.walk(Routine)[0] + assert routine_main.name == "main" + + routine_foo_c: Routine = root_node.walk(Routine)[3] + assert routine_foo_c.name == "foo_c" + + call_foo_c: Call = routine_main.walk(Call)[6] + assert call_foo_c.routine.name == "foo" + + (result, arg_idx_list) = call_foo_c.get_callee() + result: Routine + + assert len(arg_idx_list) == 2 + assert arg_idx_list[0] == 1 + assert arg_idx_list[1] == 0 + + assert result is routine_foo_c + + +def test_call_get_callee_6_interfaces_2_2(fortran_reader): + """ + Check that + - different argument types (different order than in tests before) and + - naming arguments resulting in a different order and + - last call argument without naming + have been correctly determined. + """ + + root_node: Node = fortran_reader.psyir_from_source(_code_test_get_callee_6) + + routine_main: Routine = root_node.walk(Routine)[0] + assert routine_main.name == "main" + + routine_foo_c: Routine = root_node.walk(Routine)[3] + assert routine_foo_c.name == "foo_c" + + call_foo_c: Call = routine_main.walk(Call)[7] + assert call_foo_c.routine.name == "foo" + + (result, arg_idx_list) = call_foo_c.get_callee() + result: Routine + + assert len(arg_idx_list) == 3 + assert arg_idx_list[0] == 1 + assert arg_idx_list[1] == 0 + assert arg_idx_list[2] == 2 + + assert result is routine_foo_c + + +def test_call_get_callee_6_interfaces_3_0_mismatch(fortran_reader): + """ + Check that matching a partial data type can also go wrong. + """ + + root_node: Node = fortran_reader.psyir_from_source(_code_test_get_callee_6) + + routine_main: Routine = root_node.walk(Routine)[0] + assert routine_main.name == "main" + + routine_foo_optional: Routine = root_node.walk(Routine)[4] + assert routine_foo_optional.name == "foo_optional" + + call_foo_optional: Call = routine_main.walk(Call)[8] + assert call_foo_optional.routine.name == "foo" + + with pytest.raises(CallMatchingArgumentsNotFoundError) as einfo: + call_foo_optional.get_callee() + + assert "Argument partial type mismatch of call argument" in ( + str(einfo.value) + ) + + +def test_call_get_callee_7_matching_arguments_not_found(fortran_reader): + """ + Trigger error that matching arguments were not found + """ + code = """ +module some_mod + implicit none +contains + + subroutine main() + integer :: e, f, g + ! Use named argument 'd', which doesn't exist + ! to trigger an error when searching for the matching routine. + call foo(e, f, d=g) + end subroutine + + ! Matching routine + subroutine foo(a, b, c) + integer :: a, b, c + end subroutine + +end module some_mod""" + + psyir = fortran_reader.psyir_from_source(code) + + routine_main: Routine = psyir.walk(Routine)[0] + assert routine_main.name == "main" + + call_foo: Call = routine_main.walk(Call)[0] + + with pytest.raises(CallMatchingArgumentsNotFoundError) as err: + call_foo.get_callee() + + assert ( + "Found routines, but no routine with matching arguments" + " found for 'call foo(e, f, d=g)':" in str(err.value) + ) + + print(str(err.value)) + assert ( + "CallMatchingArgumentsNotFound: Named argument" + " 'd' not found." in str(err.value) + ) + + +def test_set_routine(fortran_reader): + """Test the routine setter (not in the constructor).""" + + code = ( + "module test_mod\n" + "contains\n" + "subroutine main\n" + " real :: var = 0.0\n" + " call sub(var, opt=1.0)\n" + "end subroutine main\n" + "subroutine sub(x, opt)\n" + " real, intent(inout) :: x\n" + " real :: opt\n" + " x = x + 1.0\n" + "end subroutine sub\n" + "end module test_mod\n" + ) + + psyir = fortran_reader.psyir_from_source(code) + call = psyir.walk(Call)[0] + routine = psyir.walk(Routine)[0] + + crm = CallRoutineMatcher() + crm.set_call_node(call) + crm.set_routine_node(routine) + + +def test_fn_call_get_callees(fortran_reader): + """ + Test that get_callees() works for a function call. + """ + code = """ +module some_mod + implicit none + integer :: luggage +contains + subroutine top() + luggage = 0 + luggage = luggage + my_func(1) + end subroutine top + + function my_func(val) + integer, intent(in) :: val + integer :: my_func + my_func = 1 + val + end function my_func +end module some_mod""" + psyir = fortran_reader.psyir_from_source(code) + call = psyir.walk(Call)[0] + result = call.get_callees() + assert result == [psyir.walk(Routine)[1]] + + +@pytest.mark.usefixtures("clear_module_manager_instance") +def test_call_get_callees_wildcard_import_container( + fortran_reader, tmpdir, monkeypatch +): + """ + Check that get_callees() works successfully for a routine accessed via + a wildcard import from a module in another file. + """ + code = """ +module other_mod + use some_mod +contains + subroutine run_it() + call just_do_it() + end subroutine run_it +end module other_mod +""" + psyir = fortran_reader.psyir_from_source(code) + call = psyir.walk(Call)[0] + # This should fail as it can't find the module. + with pytest.raises(NotImplementedError) as err: + _ = call.get_callees() + assert ( + "Failed to find the source code of the unresolved routine " + "'just_do_it' - looked at any routines in the same source file" + in str(err.value) + ) + # Create the module containing the subroutine definition, + # write it to file and set the search path so that PSyclone can find it. + path = str(tmpdir) + monkeypatch.setattr(Config.get(), "_include_paths", [path]) + + with open( + os.path.join(path, "some_mod.f90"), "w", encoding="utf-8" + ) as mfile: + mfile.write( + """\ +module some_mod +contains + subroutine just_do_it() + write(*,*) "hello" + end subroutine just_do_it +end module some_mod""" + ) + routines = call.get_callees() + assert len(routines) == 1 + assert isinstance(routines[0], Routine) + assert routines[0].name == "just_do_it" + + +@pytest.mark.usefixtures("clear_module_manager_instance") +def test_call_get_callees_unresolved(fortran_reader, tmpdir, monkeypatch): + """ + Test that get_callees() raises the expected error if the called routine + is unresolved. + """ + code = """ +subroutine top() + call bottom() +end subroutine top""" + psyir = fortran_reader.psyir_from_source(code) + call = psyir.walk(Call)[0] + with pytest.raises(NotImplementedError) as err: + _ = call.get_callees() + assert ( + "Failed to find the source code of the unresolved routine 'bottom'" + " - looked at any routines in the same source file and there are " + "no wildcard imports." in str(err.value) + ) + # Repeat but in the presence of a wildcard import. + code = """ +subroutine top() + use some_mod_somewhere + call bottom() +end subroutine top""" + psyir = fortran_reader.psyir_from_source(code) + call = psyir.walk(Call)[0] + with pytest.raises(NotImplementedError) as err: + _ = call.get_callees() + assert ( + "Failed to find the source code of the unresolved routine 'bottom'" + " - looked at any routines in the same source file and attempted " + "to resolve the wildcard imports from ['some_mod_somewhere']. " + "However, failed to find the source for ['some_mod_somewhere']. " + "The module search path is set to []" in str(err.value) + ) + # Repeat but when some_mod_somewhere *is* resolved but doesn't help us + # find the routine we're looking for. + mod_manager = ModuleManager.get() + monkeypatch.setattr(mod_manager, "_instance", None) + path = str(tmpdir) + monkeypatch.setattr(Config.get(), "_include_paths", [path]) + with open( + os.path.join(path, "some_mod_somewhere.f90"), "w", encoding="utf-8" + ) as ofile: + ofile.write( + """\ +module some_mod_somewhere +end module some_mod_somewhere +""" + ) + with pytest.raises(NotImplementedError) as err: + _ = call.get_callees() + assert ( + "Failed to find the source code of the unresolved routine 'bottom'" + " - looked at any routines in the same source file and wildcard " + "imports from ['some_mod_somewhere']." in str(err.value) + ) + mod_manager = ModuleManager.get() + monkeypatch.setattr(mod_manager, "_instance", None) + code = """ +subroutine top() + use another_mod, only: this_one + call this_one() +end subroutine top""" + psyir = fortran_reader.psyir_from_source(code) + call = psyir.walk(Call)[0] + with pytest.raises(NotImplementedError) as err: + _ = call.get_callees() + assert ( + "RoutineSymbol 'this_one' is imported from Container 'another_mod'" + " but the source defining that container could not be found. The " + "module search path is set to [" in str(err.value) + ) + + +def test_call_get_callees_interface(fortran_reader): + """ + Check that get_callees() works correctly when the target of a call is + actually a generic interface. + """ + code = """ +module my_mod + + interface bottom + module procedure :: rbottom, ibottom + end interface bottom +contains + subroutine top() + integer :: luggage + luggage = 0 + call bottom(luggage) + end subroutine top + + subroutine ibottom(luggage) + integer :: luggage + luggage = luggage + 1 + end subroutine ibottom + + subroutine rbottom(luggage) + real :: luggage + luggage = luggage + 1.0 + end subroutine rbottom +end module my_mod +""" + psyir = fortran_reader.psyir_from_source(code) + call = psyir.walk(Call)[0] + callees = call.get_callees() + assert len(callees) == 2 + assert isinstance(callees[0], Routine) + assert callees[0].name == "rbottom" + assert isinstance(callees[1], Routine) + assert callees[1].name == "ibottom" + + +def test_call_get_callees_unsupported_type(fortran_reader): + """ + Check that get_callees() raises the expected error when the called routine + is of UnsupportedFortranType. This is hard to achieve so we have to + manually construct some aspects of the test case. + + """ + code = """ +module my_mod + integer, target :: value +contains + subroutine top() + integer :: luggage + luggage = bottom() + end subroutine top + function bottom() result(fval) + integer, pointer :: fval + fval => value + end function bottom +end module my_mod +""" + psyir = fortran_reader.psyir_from_source(code) + container = psyir.children[0] + routine = container.find_routine_psyir("bottom") + rsym = container.symbol_table.lookup(routine.name) + # Ensure the type of this RoutineSymbol is UnsupportedFortranType. + rsym.datatype = UnsupportedFortranType("integer, pointer :: fval") + assign = container.walk(Assignment)[0] + # Currently `bottom()` gets matched by fparser2 as a structure constructor + # and the fparser2 frontend leaves this as a CodeBlock (TODO #2429) so + # replace it with a Call. Once #2429 is fixed the next two lines can be + # removed. + assert isinstance(assign.rhs, CodeBlock) + assign.rhs.replace_with(Call.create(rsym)) + call = psyir.walk(Call)[0] + with pytest.raises(NotImplementedError) as err: + _ = call.get_callees() + assert ( + "RoutineSymbol 'bottom' exists in Container 'my_mod' but is of " + "UnsupportedFortranType" in str(err.value) + ) + + +def test_call_get_callees_local(fortran_reader): + """ + Check that get_callees() works as expected when the target of the Call + exists in the same Container as the call site. + """ + code = """ +module some_mod + implicit none + integer :: luggage +contains + subroutine top() + luggage = 0 + call bottom() + end subroutine top + + subroutine bottom() + luggage = luggage + 1 + end subroutine bottom +end module some_mod""" + psyir = fortran_reader.psyir_from_source(code) + call = psyir.walk(Call)[0] + result = call.get_callees() + assert result == [psyir.walk(Routine)[1]] + + +def test_call_get_callee_matching_arguments_not_found(fortran_reader): + """ + Trigger error that matching arguments were not found. + In this test, this is caused by omitting the required third non-optional + argument. + """ + code = """ +module some_mod + implicit none +contains + + subroutine main() + integer :: e, f + ! Omit the 3rd required argument + call foo(e, f) + end subroutine + + ! Routine matching by 'name', but not by argument matching + subroutine foo(a, b, c) + integer :: a, b, c + end subroutine + +end module some_mod""" + + psyir = fortran_reader.psyir_from_source(code) + + routine_main: Routine = psyir.walk(Routine)[0] + assert routine_main.name == "main" + + call_foo: Call = routine_main.walk(Call)[0] + + with pytest.raises(CallMatchingArgumentsNotFoundError) as err: + call_foo.get_callee() + + assert ( + "CallMatchingArgumentsNotFound: Found routines, but" + " no routine with matching arguments found for 'call" + " foo(e, f)':" in str(err.value) + ) + + assert ( + "CallMatchingArgumentsNotFound: Argument 'c' in" + " subroutine 'foo' not handled." in str(err.value) + ) + + +def test_call_get_callees_file_container(fortran_reader): + """ + Check that get_callees works if the called routine happens to be in file + scope, even when there's no Container. + """ + code = """ + subroutine top() + integer :: luggage + luggage = 0 + call bottom(luggage) + end subroutine top + + subroutine bottom(luggage) + integer :: luggage + luggage = luggage + 1 + end subroutine bottom +""" + psyir = fortran_reader.psyir_from_source(code) + call = psyir.walk(Call)[0] + result = call.get_callees() + assert len(result) == 1 + assert isinstance(result[0], Routine) + assert result[0].name == "bottom" + + +def test_call_get_callees_no_container(fortran_reader): + """ + Check that get_callees() raises the expected error when the Call is not + within a Container and the target routine cannot be found. + """ + # To avoid having the routine symbol immediately dismissed as + # unresolved, the code that we initially process *does* have a Container. + code = """ +module my_mod + +contains + subroutine top() + integer :: luggage + luggage = 0 + call bottom(luggage) + end subroutine top + + subroutine bottom(luggage) + integer :: luggage + luggage = luggage + 1 + end subroutine bottom +end module my_mod +""" + psyir = fortran_reader.psyir_from_source(code) + top_routine = psyir.walk(Routine)[0] + # Deliberately make the Routine node an orphan so there's no Container. + top_routine.detach() + call = top_routine.walk(Call)[0] + with pytest.raises(SymbolError) as err: + _ = call.get_callees() + assert ( + "Failed to find a Routine named 'bottom' in code:\n'subroutine " + "top()" in str(err.value) + ) + + +def test_call_get_callees_wildcard_import_local_container(fortran_reader): + """ + Check that get_callees() works successfully for a routine accessed via + a wildcard import from another module in the same file. + """ + code = """ +module some_mod +contains + subroutine just_do_it() + write(*,*) "hello" + end subroutine just_do_it +end module some_mod +module other_mod + use some_mod +contains + subroutine run_it() + call just_do_it() + end subroutine run_it +end module other_mod +""" + psyir = fortran_reader.psyir_from_source(code) + call = psyir.walk(Call)[0] + routines = call.get_callees() + assert len(routines) == 1 + assert isinstance(routines[0], Routine) + assert routines[0].name == "just_do_it" + + +def test_call_get_callees_import_local_container(fortran_reader): + """ + Check that get_callees() works successfully for a routine accessed via + a specific import from another module in the same file. + """ + code = """ +module some_mod +contains + subroutine just_do_it() + write(*,*) "hello" + end subroutine just_do_it +end module some_mod +module other_mod + use some_mod, only: just_do_it +contains + subroutine run_it() + call just_do_it() + end subroutine run_it +end module other_mod +""" + psyir = fortran_reader.psyir_from_source(code) + call = psyir.walk(Call)[0] + routines = call.get_callees() + assert len(routines) == 1 + assert isinstance(routines[0], Routine) + assert routines[0].name == "just_do_it" + + +def test_get_callees_code_block(fortran_reader): + """Test that get_callees() raises the expected error when the called + routine is in a CodeBlock.""" + code = """ +module some_mod + implicit none + integer :: luggage +contains + subroutine top() + luggage = 0 + luggage = luggage + real(my_func(1)) + end subroutine top + + complex function my_func(val) + integer, intent(in) :: val + my_func = CMPLX(1 + val, 1.0) + end function my_func +end module some_mod""" + psyir = fortran_reader.psyir_from_source(code) + call = psyir.walk(Call)[1] + with pytest.raises(SymbolError) as err: + _ = call.get_callees() + assert ( + "Failed to find a Routine named 'my_func' in Container " + "'some_mod'" in str(err.value) + ) + + +@pytest.mark.usefixtures("clear_module_manager_instance") +def test_get_callees_follow_imports(fortran_reader, tmpdir, monkeypatch): + """ + Test that get_callees() follows imports to find the definition of the + called routine. + """ + code = """ +module some_mod + use other_mod, only: pack_it + implicit none +contains + subroutine top() + integer :: luggage = 0 + call pack_it(luggage) + end subroutine top +end module some_mod""" + # Create the module containing an import of the subroutine definition, + # write it to file and set the search path so that PSyclone can find it. + path = str(tmpdir) + monkeypatch.setattr(Config.get(), "_include_paths", [path]) + + with open( + os.path.join(path, "other_mod.f90"), "w", encoding="utf-8" + ) as mfile: + mfile.write( + """\ + module other_mod + use another_mod, only: pack_it + contains + end module other_mod + """ + ) + # Finally, create the module containing the routine definition. + with open( + os.path.join(path, "another_mod.f90"), "w", encoding="utf-8" + ) as mfile: + mfile.write( + """\ + module another_mod + contains + subroutine pack_it(arg) + integer, intent(inout) :: arg + arg = arg + 2 + end subroutine pack_it + end module another_mod + """ + ) + psyir = fortran_reader.psyir_from_source(code) + call = psyir.walk(Call)[0] + result = call.get_callees() + assert len(result) == 1 + assert isinstance(result[0], Routine) + assert result[0].name == "pack_it" + + +@pytest.mark.usefixtures("clear_module_manager_instance") +def test_get_callees_import_private_clash(fortran_reader, tmpdir, monkeypatch): + """ + Test that get_callees() raises the expected error if a module from which + a routine is imported has a private shadow of that routine (and thus we + don't know where to look for the target routine). + """ + code = """ +module some_mod + use other_mod, only: pack_it + implicit none +contains + subroutine top() + integer :: luggage = 0 + call pack_it(luggage) + end subroutine top +end module some_mod""" + # Create the module containing a private routine with the name we are + # searching for, write it to file and set the search path so that PSyclone + # can find it. + path = str(tmpdir) + monkeypatch.setattr(Config.get(), "_include_paths", [path]) + + with open( + os.path.join(path, "other_mod.f90"), "w", encoding="utf-8" + ) as mfile: + mfile.write( + """\ + module other_mod + use another_mod + private pack_it + contains + function pack_it(arg) + integer :: arg + integer :: pack_it + end function pack_it + end module other_mod + """ + ) + psyir = fortran_reader.psyir_from_source(code) + call = psyir.walk(Call)[0] + with pytest.raises(NotImplementedError) as err: + _ = call.get_callees() + assert ( + "RoutineSymbol 'pack_it' is imported from Container 'other_mod' " + "but that Container defines a private Symbol of the same name. " + "Searching for the Container that defines a public Routine with " + "that name is not yet supported - TODO #924" in str(err.value) + ) + + +def test_apply_empty_routine_coverage_option_check_strict_array_datatype( + fortran_reader, fortran_writer, tmpdir +): + """For coverage of particular branch in `inline_trans.py`.""" + code = ( + "module test_mod\n" + "contains\n" + " subroutine run_it()\n" + " integer, dimension(6) :: i\n" + " i = 10\n" + " call sub(i)\n" + " end subroutine run_it\n" + " subroutine sub(idx)\n" + " integer, dimension(:) :: idx\n" + " end subroutine sub\n" + "end module test_mod\n" + ) + psyir = fortran_reader.psyir_from_source(code) + routine = psyir.walk(Call)[0] + inline_trans = InlineTrans() + inline_trans.set_option(check_argument_strict_array_datatype=False) + inline_trans.apply(routine) + output = fortran_writer(psyir) + assert " i = 10\n\n" " end subroutine run_it\n" in output + assert Compile(tmpdir).string_compiles(output) + + +def test_apply_array_access_check_unresolved_symbols_error( + fortran_reader, fortran_writer, tmpdir +): + """ + This check solely exists for the coverage report to + catch the simple case `if not check_unresolved_symbols:` + in `symbol_table.py` + + """ + code = ( + "module test_mod\n" + "contains\n" + " subroutine run_it()\n" + " integer :: i\n" + " real :: a(10)\n" + " do i=1,10\n" + " call sub(a, i)\n" + " end do\n" + " end subroutine run_it\n" + " subroutine sub(x, ivar)\n" + " real, intent(inout), dimension(10) :: x\n" + " integer, intent(in) :: ivar\n" + " integer :: i\n" + " do i = 1, 10\n" + " x(i) = 2.0*ivar\n" + " end do\n" + " end subroutine sub\n" + "end module test_mod\n" + ) + psyir = fortran_reader.psyir_from_source(code) + routine = psyir.walk(Call)[0] + inline_trans = InlineTrans() + inline_trans.set_option(check_unresolved_symbols=False) + inline_trans.apply(routine) + output = fortran_writer(psyir) + assert ( + " do i = 1, 10, 1\n" + " do i_1 = 1, 10, 1\n" + " a(i_1) = 2.0 * i\n" + " enddo\n" in output + ) + assert Compile(tmpdir).string_compiles(output) + + +def test_apply_array_access_check_unresolved_override_option( + fortran_reader, fortran_writer, tmpdir +): + """ + This check solely exists for the coverage report to catch + the case where the override option to ignore unresolved + types is used. + + """ + code = ( + "module test_mod\n" + "use does_not_exist\n" + "contains\n" + " subroutine run_it()\n" + " type(unknown_type) :: a\n" + " call sub(a%unresolved_type)\n" + " end subroutine run_it\n" + " subroutine sub(a)\n" + " type(unresolved) :: a\n" + " end subroutine sub\n" + "end module test_mod\n" + ) + psyir = fortran_reader.psyir_from_source(code) + call = psyir.walk(Call)[0] + inline_trans = InlineTrans() + inline_trans.set_option( + check_argument_ignore_unresolved_types=True + ) + inline_trans.apply(call) diff --git a/src/psyclone/tests/psyir/transformations/inline_trans_test.py b/src/psyclone/tests/psyir/transformations/inline_trans_test.py index 2e61e19530..2ad16f0c6c 100644 --- a/src/psyclone/tests/psyir/transformations/inline_trans_test.py +++ b/src/psyclone/tests/psyir/transformations/inline_trans_test.py @@ -34,14 +34,20 @@ # Author: A. R. Porter, STFC Daresbury Lab # Modified: R. W. Ford and S. Siso, STFC Daresbury Lab -'''This module tests the inlining transformation. -''' +'''This module tests the inlining transformation.''' import os import pytest from psyclone.configuration import Config -from psyclone.psyir.nodes import Call, IntrinsicCall, Reference, Routine, Loop +from psyclone.psyir.nodes import ( + Call, + IntrinsicCall, + Loop, + Node, + Reference, + Routine, +) from psyclone.psyir.symbols import ( AutomaticInterface, DataSymbol, UnresolvedType) from psyclone.psyir.transformations import ( @@ -101,6 +107,50 @@ def test_apply_empty_routine(fortran_reader, fortran_writer, tmpdir): assert Compile(tmpdir).string_compiles(output) +def test_apply_argument_clash(fortran_reader, fortran_writer, tmpdir): + ''' + Check that the formal arguments to the inlined routine are not included + when checking for clashes (since they will be replaced by the actual + arguments to the call). + ''' + + code_clash = """ + subroutine sub(Istr) + integer :: Istr + real :: x + x = 2.0*x + call sub_sub(Istr) + end subroutine sub + + subroutine sub_sub(Istr) + integer :: i + integer :: Istr + real :: b(10) + + b(Istr:10) = 1.0 + end subroutine sub_sub""" + + psyir = fortran_reader.psyir_from_source(code_clash) + call = psyir.walk(Call)[0] + inline_trans = InlineTrans() + inline_trans.apply(call) + expected = '''\ +subroutine sub(istr) + integer :: istr + real :: x + integer :: i + real, dimension(10) :: b + + x = 2.0 * x + b(istr:) = 1.0 + +end subroutine sub +''' + output = fortran_writer(psyir) + assert expected in output + assert Compile(tmpdir).string_compiles(output) + + def test_apply_single_return(fortran_reader, fortran_writer, tmpdir): '''Check that a call to a routine containing only a return statement is removed. ''' @@ -154,6 +204,44 @@ def test_apply_return_then_cb(fortran_reader, fortran_writer, tmpdir): assert Compile(tmpdir).string_compiles(output) +def test_apply_provided_routine(fortran_reader, fortran_writer, tmpdir): + ''' Check that the apply() method works also for a provided routine. ''' + code = ( + "module test_mod\n" + "contains\n" + " subroutine run_it()\n" + " integer :: i\n" + " real :: a(10)\n" + " do i=1,10\n" + " a(i) = 1.0\n" + " call sub(a(i))\n" + " end do\n" + " end subroutine run_it\n" + " subroutine sub(x)\n" + " real, intent(inout) :: x\n" + " x = 2.0*x\n" + " end subroutine sub\n" + " subroutine sub2(x, y)\n" + " real, intent(inout) :: x, y\n" + " x = 2.0*x\n" + " end subroutine sub2\n" + "end module test_mod\n") + psyir = fortran_reader.psyir_from_source(code) + + call = psyir.walk(Call)[0] + + routine = psyir.walk(Routine)[1] + inline_trans = InlineTrans() + inline_trans.apply(call, routine) + + routine = psyir.walk(Routine)[2] + with pytest.raises(TransformationError) as einfo: + inline_trans.apply(call, routine) + + assert "Routine's argument(s) don't match:" in str(einfo.value) + assert "Argument 'y' in subroutine 'sub2' not handled" in str(einfo.value) + + def test_apply_array_arg(fortran_reader, fortran_writer, tmpdir): ''' Check that the apply() method works correctly for a very simple call to a routine with an array reference as argument. ''' @@ -262,6 +350,7 @@ def test_apply_gocean_kern(fortran_reader, fortran_writer, monkeypatch): monkeypatch.setattr(Config.get(), '_include_paths', [str(src_dir)]) psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() + inline_trans.set_option(check_argument_matching=False) with pytest.raises(TransformationError) as err: inline_trans.apply(psyir.walk(Call)[0]) if ("actual argument 'cu_fld' corresponding to an array formal " @@ -319,6 +408,7 @@ def test_apply_struct_arg(fortran_reader, fortran_writer, tmpdir): f"end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() + inline_trans.set_option(check_argument_matching=False) for routine in psyir.walk(Routine)[0].walk(Call, stop_type=Call): inline_trans.apply(routine) @@ -391,6 +481,7 @@ def test_apply_unresolved_struct_arg(fortran_reader, fortran_writer): "end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() + inline_trans.set_option(check_argument_matching=False) calls = psyir.walk(Call) # First one should be fine. inline_trans.apply(calls[0]) @@ -453,6 +544,7 @@ def test_apply_struct_slice_arg(fortran_reader, fortran_writer, tmpdir): f"end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() + inline_trans.set_option(check_argument_matching=False) for routine in psyir.walk(Routine)[0].walk(Call, stop_type=Call): inline_trans.apply(routine) output = fortran_writer(psyir) @@ -490,6 +582,7 @@ def test_apply_struct_local_limits_caller(fortran_reader, fortran_writer, f"end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() + inline_trans.set_option(check_argument_matching=False) for routine in psyir.walk(Routine)[0].walk(Call, stop_type=Call): inline_trans.apply(routine) output = fortran_writer(psyir) @@ -534,6 +627,7 @@ def test_apply_struct_local_limits_caller_decln(fortran_reader, fortran_writer, f"end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() + inline_trans.set_option(check_argument_matching=False) for routine in psyir.walk(Routine)[0].walk(Call, stop_type=Call): inline_trans.apply(routine) output = fortran_writer(psyir) @@ -587,6 +681,7 @@ def test_apply_struct_local_limits_routine(fortran_reader, fortran_writer, f"end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() + inline_trans.set_option(check_argument_matching=False) for routine in psyir.walk(Routine)[0].walk(Call, stop_type=Call): inline_trans.apply(routine) output = fortran_writer(psyir) @@ -641,6 +736,7 @@ def test_apply_array_limits_are_formal_args(fortran_reader, fortran_writer): ''' psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() + inline_trans.set_option(check_argument_matching=False) acall = psyir.walk(Call, stop_type=Call)[0] inline_trans.apply(acall) output = fortran_writer(psyir) @@ -687,6 +783,7 @@ def test_apply_allocatable_array_arg(fortran_reader, fortran_writer): ) psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() + inline_trans.set_option(check_argument_matching=False) for routine in psyir.walk(Routine)[0].walk(Call, stop_type=Call): if not isinstance(routine, IntrinsicCall): inline_trans.apply(routine) @@ -747,6 +844,7 @@ def test_apply_array_slice_arg(fortran_reader, fortran_writer, tmpdir): "end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() + inline_trans.set_option(check_argument_matching=False) for call in psyir.walk(Routine)[0].walk(Call, stop_type=Call): inline_trans.apply(call) output = fortran_writer(psyir) @@ -797,6 +895,7 @@ def test_apply_struct_array_arg(fortran_reader, fortran_writer, tmpdir): psyir = fortran_reader.psyir_from_source(code) loops = psyir.walk(Loop) inline_trans = InlineTrans() + inline_trans.set_option(check_argument_matching=False) inline_trans.apply(loops[0].loop_body.children[1]) inline_trans.apply(loops[1].loop_body.children[1]) inline_trans.apply(loops[2].loop_body.children[1]) @@ -852,6 +951,7 @@ def test_apply_struct_array_slice_arg(fortran_reader, fortran_writer, tmpdir): f"end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() + inline_trans.set_option(check_argument_matching=False) for call in psyir.walk(Call): if not isinstance(call, IntrinsicCall): if call.arguments[0].debug_string() == "grid%local%data": @@ -924,6 +1024,7 @@ def test_apply_struct_array(fortran_reader, fortran_writer, tmpdir, f"end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() + inline_trans.set_option(check_argument_matching=False) if "use some_mod" in type_decln: with pytest.raises(TransformationError) as err: inline_trans.apply(psyir.walk(Call)[0]) @@ -969,6 +1070,7 @@ def test_apply_repeated_module_use(fortran_reader, fortran_writer): "end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() + inline_trans.set_option(check_argument_matching=False) for call in psyir.walk(Routine)[0].walk(Call, stop_type=Call): inline_trans.apply(call) output = fortran_writer(psyir) @@ -1616,11 +1718,14 @@ def test_validate_calls_find_routine(fortran_reader): inline_trans = InlineTrans() with pytest.raises(TransformationError) as err: inline_trans.validate(call) - assert ("Cannot inline routine 'sub' because its source cannot be found: " - "Failed to find the source code of the unresolved routine 'sub' - " - "looked at any routines in the same source file and attempted to " - "resolve the wildcard imports from ['some_mod']. However, failed " - "to find the source for ['some_mod']" in str(err.value)) + assert ( + "Cannot inline routine 'sub' because its source cannot be found:\n" + "Failed to find the source code of the unresolved routine 'sub' - " + "looked at any routines in the same source file and attempted to " + "resolve the wildcard imports from ['some_mod']. However, failed " + "to find the source for ['some_mod']" + in str(err.value) + ) def test_validate_return_stmt(fortran_reader): @@ -1673,9 +1778,13 @@ def test_validate_codeblock(fortran_reader): inline_trans = InlineTrans() with pytest.raises(TransformationError) as err: inline_trans.validate(call) - assert ("Routine 'sub' contains one or more CodeBlocks and therefore " - "cannot be inlined. (If you are confident " in str(err.value)) - inline_trans.validate(call, options={"force": True}) + assert ( + "Routine 'sub' contains one or more CodeBlocks and therefore " + "cannot be inlined. (If you are confident " + in str(err.value) + ) + inline_trans.set_option(check_inline_codeblocks=False) + inline_trans.validate(call) def test_validate_unsupportedtype_argument(fortran_reader): @@ -1703,9 +1812,19 @@ def test_validate_unsupportedtype_argument(fortran_reader): inline_trans = InlineTrans() with pytest.raises(TransformationError) as err: inline_trans.validate(routine) - assert ("Routine 'sub' cannot be inlined because it contains a Symbol 'x' " - "which is an Argument of UnsupportedType: 'REAL, POINTER, " - "INTENT(INOUT) :: x'" in str(err.value)) + + assert ( + "Transformation Error: Cannot inline routine 'sub'" + " because its source cannot be found:" + in str(err.value) + ) + assert ( + "Argument partial type mismatch of call argument" + " 'Reference[name:'ptr']' and routine argument 'x:" + " DataSymbol'" + in str(err.value) + ) def test_validate_unknowninterface(fortran_reader, fortran_writer, tmpdir): @@ -1731,9 +1850,11 @@ def test_validate_unknowninterface(fortran_reader, fortran_writer, tmpdir): inline_trans = InlineTrans() with pytest.raises(TransformationError) as err: inline_trans.validate(routine) - assert (" Routine 'sub' cannot be inlined because it contains a Symbol " - "'x' with an UnknownInterface: 'REAL, POINTER :: x'" - in str(err.value)) + assert ( + " Routine 'sub' cannot be inlined because it contains a Symbol " + "'x' with an UnknownInterface: 'REAL, POINTER :: x'" + in str(err.value) + ) # But if the interface is known, it has no problem inlining it xvar = psyir.walk(Routine)[1].symbol_table.lookup("x") @@ -1774,8 +1895,11 @@ def test_validate_static_var(fortran_reader): inline_trans = InlineTrans() with pytest.raises(TransformationError) as err: inline_trans.validate(routine) - assert ("Routine 'sub' cannot be inlined because it has a static (Fortran " - "SAVE) interface for Symbol 'state'." in str(err.value)) + assert ( + "Routine 'sub' cannot be inlined because it has a static (Fortran " + "SAVE) interface for Symbol 'state'." + in str(err.value) + ) @pytest.mark.parametrize("code_body", ["idx = idx + 5_i_def", @@ -1808,9 +1932,12 @@ def test_validate_unresolved_precision_sym(fortran_reader, code_body): var_name = "wp" else: var_name = "i_def" - assert (f"Routine 'sub' cannot be inlined because it accesses variable " - f"'{var_name}' and this cannot be found in any of the containers " - f"directly imported into its symbol table" in str(err.value)) + assert ( + "Routine 'sub' cannot be inlined because it accesses variable " + f"'{var_name}' and this cannot be found in any of the containers " + "directly imported into its symbol table" + in str(err.value) + ) def test_validate_resolved_precision_sym(fortran_reader, monkeypatch, @@ -1937,9 +2064,18 @@ def test_validate_wrong_number_args(fortran_reader): inline_trans = InlineTrans() with pytest.raises(TransformationError) as err: inline_trans.validate(call) - assert ("Cannot inline 'call sub(i, trouble)' because the number of " - "arguments supplied to the call (2) does not match the number of " - "arguments the routine is declared to have (1)" in str(err.value)) + + assert ( + "Transformation Error: Cannot inline routine 'sub'" + " because its source cannot be found:\n" + "CallMatchingArgumentsNotFound: Found routines," + " but no routine with matching arguments found" + " for 'call sub(i, trouble)':\n" + "CallMatchingArgumentsNotFound: More arguments" + " in call ('call sub(i, trouble)') than callee" + " (routine 'sub')" + in str(err.value) + ) def test_validate_unresolved_import(fortran_reader): @@ -2023,6 +2159,7 @@ def test_validate_array_reshape(fortran_reader): psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() + inline_trans.set_option(check_argument_matching=False) with pytest.raises(TransformationError) as err: inline_trans.validate(call) assert ("Cannot inline routine 's' because it reshapes an argument: actual" @@ -2055,6 +2192,7 @@ def test_validate_array_arg_expression(fortran_reader): psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() + inline_trans.set_option(check_argument_matching=False) with pytest.raises(TransformationError) as err: inline_trans.validate(call) assert ("The call 'call s(a + b, 10)\n' cannot be inlined because actual " @@ -2081,6 +2219,7 @@ def test_validate_indirect_range(fortran_reader): psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() + inline_trans.set_option(check_argument_matching=False) with pytest.raises(TransformationError) as err: inline_trans.validate(call) assert ("Cannot inline routine 'sub' because argument 'var(indices(:))' " @@ -2105,6 +2244,7 @@ def test_validate_non_unit_stride_slice(fortran_reader): psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() + inline_trans.set_option(check_argument_matching=False) with pytest.raises(TransformationError) as err: inline_trans.validate(call) assert ("Cannot inline routine 'sub' because one of its arguments is an " @@ -2112,12 +2252,30 @@ def test_validate_non_unit_stride_slice(fortran_reader): str(err.value)) -def test_validate_named_arg(fortran_reader): - '''Test that the validate method rejects an attempt to inline a routine - that has a named argument.''' - # In reality, the routine with a named argument would almost certainly - # use the 'present' intrinsic but, since that gives a CodeBlock that itself - # prevents inlining, our test example omits it. +def test_set_options(fortran_reader): + '''Test that simply sets all options for sake of the coverage test.''' + + inline_trans = InlineTrans() + inline_trans.set_option( + ignore_missing_modules=False, + check_argument_strict_array_datatype=False, + check_argument_matching=False, + + check_inline_codeblocks=False, + check_diff_container_clashes=False, + check_diff_container_clashes_unres_types=False, + check_resolve_imports=False, + check_static_interface=False, + check_array_type=False, + check_unsupported_type=False, + check_unresolved_symbols=False, + ) + + +def test_apply_named_arg(fortran_reader): + '''Test that the validate method inlines a routine that has a named + argument.''' + code = ( "module test_mod\n" "contains\n" @@ -2127,10 +2285,35 @@ def test_validate_named_arg(fortran_reader): "end subroutine main\n" "subroutine sub(x, opt)\n" " real, intent(inout) :: x\n" + " real :: opt\n" + " x = x + 1.0\n" + "end subroutine sub\n" + "end module test_mod\n" + ) + psyir = fortran_reader.psyir_from_source(code) + call = psyir.walk(Call)[0] + inline_trans = InlineTrans() + + inline_trans.apply(call) + + +def test_apply_optional_arg(fortran_reader): + '''Test that the validate method inlines a routine + that has an optional argument.''' + + code = ( + "module test_mod\n" + "contains\n" + "subroutine main\n" + " real :: var = 0.0\n" + " call sub(var)\n" + "end subroutine main\n" + "subroutine sub(x, opt)\n" + " real, intent(inout) :: x\n" " real, optional :: opt\n" - " !if( present(opt) )then\n" - " ! x = x + opt\n" - " !end if\n" + " if( present(opt) )then\n" + " x = x + opt\n" + " end if\n" " x = x + 1.0\n" "end subroutine sub\n" "end module test_mod\n" @@ -2138,10 +2321,166 @@ def test_validate_named_arg(fortran_reader): psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() - with pytest.raises(TransformationError) as err: - inline_trans.validate(call) - assert ("Routine 'sub' cannot be inlined because it has a named argument " - "'opt' (TODO #924)" in str(err.value)) + inline_trans.apply(call) + + +def test_apply_optional_arg_with_special_cases(fortran_reader): + '''Test that the validate method inlines a routine + that has an optional argument. + This example has an additional if-branching condition + `1.0==1.0` which is not directly of type `Literal` + ''' + + code = ( + "module test_mod\n" + "contains\n" + "subroutine main\n" + " real :: var = 0.0\n" + " call sub(var)\n" + "end subroutine main\n" + "subroutine sub(x, opt)\n" + " real, intent(inout) :: x\n" + " real, optional :: opt\n" + " if( present(opt) )then\n" + " x = x + opt\n" + " end if\n" + " if( 1.0 == 1.0 )then\n" + " x = x\n" + " end if\n" + " x = x + 1.0\n" + "end subroutine sub\n" + "end module test_mod\n" + ) + psyir = fortran_reader.psyir_from_source(code) + call = psyir.walk(Call)[0] + inline_trans = InlineTrans() + inline_trans.apply(call) + + +def test_apply_optional_arg_error(fortran_reader): + '''Test that the validate method can't inline a routine + where the optional argument is still used. + ''' + + code = ( + "module test_mod\n" + "contains\n" + "subroutine main\n" + " real :: var = 0.0\n" + " call sub(var)\n" + "end subroutine main\n" + "subroutine sub(x, opt)\n" + " real, intent(inout) :: x\n" + " real, optional :: opt\n" + " if( present(opt) )then\n" + " x = x + opt\n" + " end if\n" + " x = x + opt\n" + "end subroutine sub\n" + "end module test_mod\n" + ) + psyir = fortran_reader.psyir_from_source(code) + call = psyir.walk(Call)[0] + inline_trans = InlineTrans() + with pytest.raises(TransformationError) as einfo: + inline_trans.apply(call) + + assert ("Subroutine argument 'opt' is not provided by call," + " but used in the subroutine." in str(einfo.value)) + + +def test_apply_unsupported_pointer_error(fortran_reader): + '''Test that the validate method can't inline a routine + where a pointer argument is used. + This covers a special code + `if ", OPTIONAL" not in sym.datatype.declaration:` + which doesn't work that reliably and should be replaced + with something more robust. + ''' + + code = ( + "module test_mod\n" + "contains\n" + "subroutine main\n" + " real :: var = 0.0\n" + " call sub(var)\n" + "end subroutine main\n" + "subroutine sub(x)\n" + " real, intent(inout), pointer :: x\n" + " x = 1.0\n" + "end subroutine sub\n" + "end module test_mod\n" + ) + psyir = fortran_reader.psyir_from_source(code) + call = psyir.walk(Call)[0] + inline_trans = InlineTrans() + with pytest.raises(TransformationError) as einfo: + inline_trans.apply(call) + + assert ("Routine 'sub' cannot be inlined because it contains a Symbol 'x'" + " which is an Argument of UnsupportedType:" + " 'REAL, INTENT(INOUT), POINTER :: x'." in str(einfo.value)) + + +def test_apply_optional_and_named_arg_2(fortran_reader): + '''Test that the validate method inlines a routine + that has an optional argument.''' + code = ( + "module test_mod\n" + "contains\n" + "subroutine main\n" + " real :: var = 0.0\n" + " call sub(var, 1.0)\n" + " ! Result:\n" + " ! var = var + 2.0 + 1.0\n" + " ! var = var + 4.0 + 1.0\n" + " ! var = var + 5.0 + 1.0\n" + " call sub(var)\n" + " ! Result:\n" + " ! var = var + 3.0\n" + " ! var = var + 6.0\n" + " ! var = var + 7.0\n" + "end subroutine main\n" + "subroutine sub(x, opt)\n" + " real, intent(inout) :: x\n" + " real, optional :: opt\n" + " if( present(opt) )then\n" + " x = x + 2.0 + opt\n" + " else\n" + " x = x + 3.0\n" + " end if\n" + " if( present(opt) )then\n" + " x = x + 4.0 + opt\n" + " x = x + 5.0 + opt\n" + " else\n" + " x = x + 6.0\n" + " x = x + 7.0\n" + " end if\n" + "end subroutine sub\n" + "end module test_mod\n" + ) + psyir: Node = fortran_reader.psyir_from_source(code) + + inline_trans = InlineTrans() + + routine_main: Routine = psyir.walk(Routine)[0] + assert routine_main.name == "main" + for call in psyir.walk(Call, stop_type=Call): + call: Call + if call.routine.name != "sub": + continue + + inline_trans.apply(call) + + assert ( + '''var = var + 2.0 + 1.0 + var = var + 4.0 + 1.0 + var = var + 5.0 + 1.0 + var = var + 3.0 + var = var + 6.0 + var = var + 7.0''' + in routine_main.debug_string() + ) CALL_IN_SUB_USE = ( @@ -2194,47 +2533,3 @@ def test_apply_merges_symbol_table_with_routine(fortran_reader): inline_trans.apply(routine) # The i_1 symbol is the renamed i from the inlined call. assert psyir.walk(Routine)[0].symbol_table.get_symbols()['i_1'] is not None - - -def test_apply_argument_clash(fortran_reader, fortran_writer, tmpdir): - ''' - Check that the formal arguments to the inlined routine are not included - when checking for clashes (since they will be replaced by the actual - arguments to the call). - ''' - - code_clash = """ - subroutine sub(Istr) - integer :: Istr - real :: x - x = 2.0*x - call sub_sub(Istr) - end subroutine sub - - subroutine sub_sub(Istr) - integer :: i - integer :: Istr - real :: b(10) - - b(Istr:10) = 1.0 - end subroutine sub_sub""" - - psyir = fortran_reader.psyir_from_source(code_clash) - call = psyir.walk(Call)[0] - inline_trans = InlineTrans() - inline_trans.apply(call) - expected = '''\ -subroutine sub(istr) - integer :: istr - real :: x - integer :: i - real, dimension(10) :: b - - x = 2.0 * x - b(istr:) = 1.0 - -end subroutine sub -''' - output = fortran_writer(psyir) - assert expected in output - assert Compile(tmpdir).string_compiles(output) diff --git a/src/psyclone/tests/psyir/transformations/omp_task_transformations_test.py b/src/psyclone/tests/psyir/transformations/omp_task_transformations_test.py index 47e2896edf..b1d2e9fa9d 100644 --- a/src/psyclone/tests/psyir/transformations/omp_task_transformations_test.py +++ b/src/psyclone/tests/psyir/transformations/omp_task_transformations_test.py @@ -178,7 +178,8 @@ def test_omptask_apply_kern(fortran_reader, fortran_writer): new_container.addchild(my_test) sym = my_test.symbol_table.lookup("test_kernel") sym.interface.container_symbol._reference = test_kernel_mod - trans = OMPTaskTrans() + trans: OMPTaskTrans = OMPTaskTrans() + trans.set_option(check_matching_arguments_of_callee=False) master = OMPSingleTrans() parallel = OMPParallelTrans() calls = my_test.walk(Call) diff --git a/utils/run_pytest.sh b/utils/run_pytest.sh new file mode 100755 index 0000000000..635bccab40 --- /dev/null +++ b/utils/run_pytest.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +SCRIPTPATH="$( cd -- "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P )" +cd "$SCRIPTPATH/.." + +NPROC=$(nproc) + +echo "Running 'pytest -n ${NPROC} src/psyclone/tests'" +time pytest -n ${NPROC} src/psyclone/tests + diff --git a/utils/run_pytest_cov.sh b/utils/run_pytest_cov.sh index 7ae7bd36c4..f8f52e98d3 100755 --- a/utils/run_pytest_cov.sh +++ b/utils/run_pytest_cov.sh @@ -23,10 +23,13 @@ COV_REPORT="xml:cov.xml" # Additional options # Also write to Terminal -#OPTS=" --cov-report term" +OPTS=" --cov-report term" + +if [[ -e cov.xml ]]; then + echo "Removing previous reporting file 'cov.xml'" + rm -rf cov.xml +fi -#echo "Running 'pytest --cov $PSYCLONE_MODULE --cov-report term-missing -n $(nproc) $SRC_DIR'" -#pytest --cov $PSYCLONE_MODULE -v --cov-report term-missing -n $(nproc) $SRC_DIR echo "Running 'pytest --cov $PSYCLONE_MODULE --cov-report $COV_REPORT -n $(nproc) $SRC_DIR'" -pytest --cov $PSYCLONE_MODULE -v --cov-report $COV_REPORT $OPTS -n $(nproc) $SRC_DIR +time pytest --cov $PSYCLONE_MODULE --cov-report $COV_REPORT $OPTS -n $(nproc) $SRC_DIR