diff --git a/src/psyclone/psyir/symbols/symbol_table.py b/src/psyclone/psyir/symbols/symbol_table.py index 2d74cd1369..768c27412a 100644 --- a/src/psyclone/psyir/symbols/symbol_table.py +++ b/src/psyclone/psyir/symbols/symbol_table.py @@ -53,6 +53,7 @@ from psyclone.psyir.symbols import ( DataSymbol, ContainerSymbol, DataTypeSymbol, GenericInterfaceSymbol, ImportInterface, RoutineSymbol, Symbol, SymbolError, UnresolvedInterface) +from psyclone.psyir.symbols.datatypes import ScalarType from psyclone.psyir.symbols.intrinsic_symbol import IntrinsicSymbol from psyclone.psyir.symbols.typed_symbol import TypedSymbol @@ -1188,15 +1189,6 @@ def remove(self, symbol): raise TypeError(f"remove() expects a Symbol argument but found: " f"'{type(symbol).__name__}'.") - # pylint: disable=unidiomatic-typecheck - if not (isinstance(symbol, (ContainerSymbol, RoutineSymbol)) or - type(symbol) is Symbol): - raise NotImplementedError( - f"remove() currently only supports generic Symbol, " - f"ContainerSymbol and RoutineSymbol types but got: " - f"'{type(symbol).__name__}'") - # pylint: enable=unidiomatic-typecheck - # Since we are manipulating the _symbols dict directly we must use # the normalised name of the symbol. norm_name = self._normalize(symbol.name) @@ -1225,6 +1217,39 @@ def remove(self, symbol): # target of a Call or a member of a GenericInterfaceSymbol. if isinstance(symbol, RoutineSymbol): self._validate_remove_routinesymbol(symbol) + elif self.node: + from psyclone.psyir.nodes import ScopingNode + from psyclone.core.variables_access_info import VariablesAccessInfo + vai = VariablesAccessInfo() + self.node.reference_accesses(vai) + for sig in vai.all_signatures: + if sig.var_name.lower() != norm_name: + continue + # The variable associated with this signature has the same + # name as the target symbol. Now look at each access... + for access in vai[sig].all_accesses: + if isinstance(access.node, ScopingNode): + continue + # We need to know whether this access is the actual + # symbol we want to remove, not whether it just has the + # same name... + # TODO + if access.node.scope.symbol_table.lookup(norm_name) is symbol: + # This access does refer to the target symbol. + if symbol.find_symbol_table(access.node) is self: + # The symbol we've found an access of is the one + # in this table. Therefore, we can only remove it + # provided that it also exists in an outer scope. + outer_sym = self.parent_symbol_table().lookup( + norm_name, otherwise=None) + if outer_sym is not symbol: + from psyclone.psyir.nodes import Statement + stmt = access.node.ancestor(Statement) + if not stmt: + stmt = access.node + raise ValueError( + f"Cannot remove {type(symbol).__name__} '{symbol.name}' because it is " + f"accessed in '{stmt.debug_string().strip()}'") # If the symbol had a tag, it should be disassociated for tag, tagged_symbol in list(self._tags.items()): @@ -1652,11 +1677,11 @@ def resolve_imports(self, container_symbols=None, symbol_target=None): continue # Examine all Symbols defined within this external container - for symbol in external_container.symbol_table.symbols: - if symbol.visibility == Symbol.Visibility.PRIVATE: + for imported_sym in external_container.symbol_table.symbols: + if imported_sym.visibility == Symbol.Visibility.PRIVATE: continue # We must ignore this symbol - if isinstance(symbol, ContainerSymbol): + if isinstance(imported_sym, ContainerSymbol): # TODO #1540: We also skip other ContainerSymbols but in # reality if this is a wildcard import we would have to # process the nested external container. @@ -1665,58 +1690,17 @@ def resolve_imports(self, container_symbols=None, symbol_target=None): # If we are just resolving a single specific symbol we don't # need to process this symbol unless the name matches. if symbol_target and not self._has_same_name( - symbol, symbol_target): + imported_sym, symbol_target): continue - # Determine if there is an Unresolved Symbol in a - # descendent symbol table that matches the name of the - # symbol we are importing and if so, move it to this - # symbol table if a symbol with the same name does not - # already exist in this symbol table. + norm_name = self._normalize(imported_sym.name) - # There are potential issues with this approach and - # with the routine in general which are captured in - # issue #2331. Issue #2271 may also help/fix some or - # all of the problems too. - - # Import here to avoid circular dependencies - # pylint: disable=import-outside-toplevel - from psyclone.psyir.nodes import ScopingNode, Reference - for scoping_node in self.node.walk(ScopingNode): - symbol_table = scoping_node.symbol_table - test_symbol = symbol_table.lookup(symbol.name, - otherwise=None) - if (test_symbol and test_symbol.is_unresolved and - all(csym in self.containersymbols for - csym in symbol_table.wildcard_imports())): - # No wildcard imports into this scope. - symbol_table.remove(test_symbol) - if test_symbol.name not in self: - # The visibility given by the inner symbol - # table does not necessarily match the one - # from the scope it should have been in (it - # doesn't have a non-default visibility, - # otherwise the symbol would already be in - # the ancestor symbol table). - test_symbol.visibility = self.default_visibility - - self.add(test_symbol) - else: - # There is already a symbol with this name - # in this table. Update all references to - # point to it. - for ref in symbol_table.node.walk(Reference): - if SymbolTable._has_same_name( - ref.symbol, symbol): - mod_symbol = self.lookup(symbol.name) - ref.symbol = mod_symbol - - # This Symbol matches the name of a symbol in the current table - if symbol.name in self: - - symbol_match = self.lookup(symbol.name) - interface = symbol_match.interface - visibility = symbol_match.visibility + if norm_name in self: + # This Symbol matches the name of a symbol in the current + # table + outer_sym = self.lookup(norm_name) + interface = outer_sym.interface + visibility = outer_sym.visibility # If the import statement is not a wildcard import, the # matching symbol must have the appropriate interface @@ -1737,45 +1721,116 @@ def resolve_imports(self, container_symbols=None, symbol_target=None): pass else: raise SymbolError( - f"Found a name clash with symbol '{symbol.name}' " + f"Found a name clash with symbol '{imported_sym.name}' " f"when importing symbols from container " f"'{c_symbol.name}'.") # If the external symbol is a subclass of the local # symbol_match, copy the external symbol properties, # otherwise ignore this step. - if isinstance(symbol, type(symbol_match)): + if isinstance(imported_sym, type(outer_sym)): # pylint: disable=unidiomatic-typecheck - if type(symbol) is not type(symbol_match): - if isinstance(symbol, TypedSymbol): + if type(imported_sym) is not type(outer_sym): + if isinstance(imported_sym, TypedSymbol): # All TypedSymbols have a mandatory datatype # argument - symbol_match.specialise( - type(symbol), datatype=symbol.datatype) + outer_sym.specialise( + type(imported_sym), + datatype=imported_sym.datatype) else: - symbol_match.specialise(type(symbol)) + outer_sym.specialise(type(imported_sym)) - symbol_match.copy_properties(symbol) + outer_sym.copy_properties(imported_sym) # Restore the interface and visibility as these are # local (not imported) properties - symbol_match.interface = interface - symbol_match.visibility = visibility - if symbol_target: - # If we were looking just for this symbol we don't need - # to continue searching - return + outer_sym.interface = interface + outer_sym.visibility = visibility else: + # This table did not already contain a symbol with this + # name. if c_symbol.wildcard_import: # This symbol is PUBLIC and inside a wildcard import, # so it needs to be declared in the symbol table. - new_symbol = symbol.copy() - new_symbol.interface = ImportInterface(c_symbol) - new_symbol.visibility = self.default_visibility - self.add(new_symbol) - if symbol_target: - # If we were looking just for this symbol then - # we're done. - return + outer_sym = imported_sym.copy() + outer_sym.interface = ImportInterface(c_symbol) + outer_sym.visibility = self.default_visibility + self.add(outer_sym) + + # Determine if there is an Unresolved Symbol in a + # descendent symbol table that matches the name of the + # symbol we are importing. If so, move it to this symbol table + # provided that a symbol with the same name does not + # already exist in this symbol table. + + # There are potential issues with this approach and + # with the routine in general which are captured in + # issue #2331. Issue #2271 may also help/fix some or + # all of the problems too. + + # Import here to avoid circular dependencies + # pylint: disable=import-outside-toplevel + from psyclone.psyir.nodes import Call, Literal, ScopingNode, Reference + # Walk down through the scopes below this one. + for scoping_node in self.node.walk(ScopingNode): + if scoping_node is self.node: + # Skip ourself. + continue + symbol_table = scoping_node.symbol_table + test_symbol = symbol_table.lookup(norm_name, + scope_limit=scoping_node, + otherwise=None) + if not test_symbol or not test_symbol.is_unresolved: + # Either this table doesn't contain a symbol with the + # same name as the imported one or it does but it is + # not unresolved so we ignore it. + continue + wildcard_imports = symbol_table.wildcard_imports() + if not all(csym in self.containersymbols for + csym in wildcard_imports): + # TODO, check whether the wildcard imports are all at + # the same scope. + import pdb; pdb.set_trace() + # There are wildcard imports other than those in the + # outer scope so we can't be certain of the origin of + # this symbol. + continue + + # We want to replace the local symbol with the new one + # in the outer scope (`outer_sym`). + from psyclone.core.variables_access_info import VariablesAccessInfo + vai = VariablesAccessInfo() + scoping_node.reference_accesses(vai) + for sig in vai.all_signatures: + if sig.var_name.lower() != norm_name: + continue + if norm_name == "wp": + import pdb; pdb.set_trace() + for access in vai[sig].all_accesses: + if access.node.scope is access.node: + # This is just the symbol table associated + # with a ScopingNode. + continue + if access.node.scope.symbol_table.lookup(norm_name) is test_symbol: + if isinstance(access.node, Reference): + access.node.symbol = outer_sym + elif isinstance(access.node, Call): + import pdb; pdb.set_trace() + print(access.node) + elif isinstance(access.node, Literal): + oldtype = access.node.datatype + newtype = ScalarType(oldtype.intrinsic, + outer_sym) + access.node.replace_with( + Literal(access.node.value, newtype)) + else: + import pdb; pdb.set_trace() + print("oh dear") + symbol_table.remove(test_symbol) + + if symbol_target: + # If we were looking just for this symbol we don't need + # to continue searching + return if symbol_target: raise KeyError( diff --git a/src/psyclone/tests/psyir/symbols/symbol_table_test.py b/src/psyclone/tests/psyir/symbols/symbol_table_test.py index 0017319b6f..40367517b5 100644 --- a/src/psyclone/tests/psyir/symbols/symbol_table_test.py +++ b/src/psyclone/tests/psyir/symbols/symbol_table_test.py @@ -527,7 +527,7 @@ def test_no_remove_routinesymbol_called(fortran_reader): with pytest.raises(ValueError) as err: table.remove(my_sub) assert ("Cannot remove RoutineSymbol 'my_sub' because it is referenced by " - "'call my_sub()" in str(err)) + "'call my_sub()" in str(err.value)) # Add the routine symbol into the filecontainer then we should be able # to remove it from the module - this validates the @@ -627,10 +627,7 @@ def test_remove_unsupported_types(): # We should not be able to remove a Symbol that is not currently supported var1 = symbols.DataSymbol("var1", symbols.REAL_TYPE) sym_table.add(var1) - with pytest.raises(NotImplementedError) as err: - sym_table.remove(var1) - assert ("remove() currently only supports generic Symbol, ContainerSymbol " - "and RoutineSymbol types but got: 'DataSymbol'" in str(err.value)) + sym_table.remove(var1) @pytest.mark.parametrize("sym_name", ["var1", "vAr1", "VAR1"])