Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integration tests for #2828 #2879

Open
wants to merge 27 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
3047fdd
[ReplaceRefByLiteral] transformation first draft
hbrunie Dec 16, 2024
04573fe
[ReplaceRefByLiteral] replace in ArrayType shape.
hbrunie Dec 17, 2024
4ee2824
[gitignore] venv and .vscode
hbrunie Dec 17, 2024
8d3c9b4
[ReplaceRefByLiteral] initial value always literal
hbrunie Dec 17, 2024
fdb3738
[ReplaceRefByLiteral] 80 chars exceeded
hbrunie Dec 17, 2024
b19b6a7
[ReplaceRefByLiteral] add test showing limit of trans
hbrunie Dec 17, 2024
bfe75cc
[ReplaceRefByLit] flake8
hbrunie Dec 17, 2024
5b06c5b
[RefbyLiteral] flake8 + test sucess
hbrunie Dec 17, 2024
5387311
Merge branch 'master' into replace_ref_by_literal
hbrunie Dec 17, 2024
dcf29a4
Merge branch 'master' into replace_ref_by_literal
hbrunie Jan 15, 2025
ff98ba7
#2828 fix license date.
hbrunie Jan 15, 2025
c31f4b9
#2828 fix module description.
hbrunie Jan 20, 2025
257aac7
#2828 class docstring and import sorted.
hbrunie Jan 20, 2025
9e41bcc
#2828 fix indent
hbrunie Jan 20, 2025
94be55c
#2828 remove __str__ method
hbrunie Jan 20, 2025
cf253a7
#2828 use new comment feature (thx JR)
hbrunie Jan 20, 2025
7bee936
#2828 comment each test
hbrunie Jan 20, 2025
463e64d
#2828 explicit warning message.
hbrunie Jan 20, 2025
4b61ea0
#2828 user friendly comment
hbrunie Jan 20, 2025
1e6089f
#2828 missing docstring
hbrunie Jan 20, 2025
880ce34
#2828 docstrings
hbrunie Jan 20, 2025
113565e
#2828 cover more edge cases.
hbrunie Jan 20, 2025
1a54922
#2828 remove outdated comment
hbrunie Jan 20, 2025
517682a
#2828 minor cleaning
hbrunie Jan 20, 2025
207cda3
#2828 fix sphinx make html
hbrunie Jan 20, 2025
b134774
#2828 fix flake8
hbrunie Jan 23, 2025
725b03d
#2828 introducing cls constant for test
hbrunie Jan 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ src/*.egg-info
.idea
.rtx.toml
.venv
.vscode
venv
cov.xml
.coverage.*
*.psycache
Expand Down
84 changes: 45 additions & 39 deletions src/psyclone/psyir/transformations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,47 +103,53 @@
from psyclone.psyir.transformations.region_trans import RegionTrans
from psyclone.psyir.transformations.replace_induction_variables_trans import \
ReplaceInductionVariablesTrans
from psyclone.psyir.transformations.replace_reference_by_literal_trans import (
ReplaceReferenceByLiteralTrans,
)
from psyclone.psyir.transformations.reference2arrayrange_trans import \
Reference2ArrayRangeTrans


# For AutoAPI documentation generation
__all__ = ['ACCKernelsTrans',
'ACCUpdateTrans',
'AllArrayAccess2LoopTrans',
'ArrayAccess2LoopTrans',
'ArrayAssignment2LoopsTrans',
'ChunkLoopTrans',
'ExtractTrans',
'FoldConditionalReturnExpressionsTrans',
'HoistLocalArraysTrans',
'HoistLoopBoundExprTrans',
'HoistTrans',
'InlineTrans',
'Abs2CodeTrans',
'DotProduct2CodeTrans',
'Matmul2CodeTrans',
'Max2CodeTrans',
'Min2CodeTrans',
'Sign2CodeTrans',
'Sum2LoopTrans',
'LoopFuseTrans',
'LoopSwapTrans',
'LoopTiling2DTrans',
'LoopTrans',
'Maxval2LoopTrans',
'Minval2LoopTrans',
'OMPLoopTrans',
'OMPTargetTrans',
'OMPTaskTrans',
'OMPTaskwaitTrans',
'ParallelLoopTrans',
'Product2LoopTrans',
'ProfileTrans',
'PSyDataTrans',
'ReadOnlyVerifyTrans',
'Reference2ArrayRangeTrans',
'RegionTrans',
'ReplaceInductionVariablesTrans',
'TransformationError',
'ValueRangeCheckTrans']
__all__ = [
"ACCKernelsTrans",
"ACCUpdateTrans",
"AllArrayAccess2LoopTrans",
"ArrayAccess2LoopTrans",
"ArrayAssignment2LoopsTrans",
"ChunkLoopTrans",
"ExtractTrans",
"FoldConditionalReturnExpressionsTrans",
"HoistLocalArraysTrans",
"HoistLoopBoundExprTrans",
"HoistTrans",
"InlineTrans",
"Abs2CodeTrans",
"DotProduct2CodeTrans",
"Matmul2CodeTrans",
"Max2CodeTrans",
"Min2CodeTrans",
"Sign2CodeTrans",
"Sum2LoopTrans",
"LoopFuseTrans",
"LoopSwapTrans",
"LoopTiling2DTrans",
"LoopTrans",
"Maxval2LoopTrans",
"Minval2LoopTrans",
"OMPLoopTrans",
"OMPTargetTrans",
"OMPTaskTrans",
"OMPTaskwaitTrans",
"ParallelLoopTrans",
"Product2LoopTrans",
"ProfileTrans",
"PSyDataTrans",
"ReadOnlyVerifyTrans",
"Reference2ArrayRangeTrans",
"RegionTrans",
"ReplaceInductionVariablesTrans",
"ReplaceReferenceByLiteralTrans",
"TransformationError",
"ValueRangeCheckTrans",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,294 @@
# -----------------------------------------------------------------------------
# BSD 3-Clause License
#
# Copyright (c) 2024-2025, Science and Technology Facilities Council.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
# -----------------------------------------------------------------------------
# Author: H. Brunie, University of Grenoble Alpes

"""Module providing a transformation that replace PsyIR Node representing a
static, constant value with a Literal Node when possible. """

from typing import Dict, List, Union

from psyclone.psyGen import Transformation
from psyclone.psyir.nodes import (
Container,
DataNode,
Literal,
Reference,
Routine,
)
from psyclone.psyir.symbols import (
ArrayType,
DataSymbol,
SymbolTable,
UnsupportedFortranType,
)
from psyclone.psyir.transformations.transformation_error import (
TransformationError,
)


class ReplaceReferenceByLiteralTrans(Transformation):
'''
This transformation takes a psyir Routine and replace all Reference psyir
Nodes by Literal if the corresponding symbol from the symbol table is
constant. That is to say the symbol is a Fortran parameter.
For example:


>>> from psyclone.psyir.backend.fortran import FortranWriter
>>> from psyclone.psyir.symbols import INTEGER_TYPE
>>> from psyclone.psyir.transformations import (
ReplaceReferenceByLiteralTrans)
>>> source = """program test
... use mymod
... type(my_type):: t1, t2, t3, t4
... integer, parameter :: x=3, y=12, z=13
... integer, parameter :: u1=1, u2=2, u3=3, u4=4
... integer i, invariant, ic1, ic2, ic3
... real, dimension(10) :: a
... invariant = 1
... do i = 1, 10
... t1%a = z
... a(ic1) = u1+(ic1+x)*ic1
... a(ic2) = u2+(ic2+y)*ic2
... a(ic3) = u3+(ic3+z)*ic3
... a(t1%a) = u4+(t1%a+u4*z)*t1%a
... end do
... end program test"""
>>> fortran_writer = FortranWriter()
>>> fortran_reader = FortranReader()
>>> psyir = fortran_reader.psyir_from_source(source)
>>> routine = psyir.walk(Routine)[0]
>>> rrbl = ReplaceReferenceByLiteralTrans()
>>> rrbl.apply(routine)
>>> written_code = fortran_writer(routine)
>>> print(written_code)
program test
use mymod
integer, parameter :: x = 3
integer, parameter :: y = 12
integer, parameter :: z = 13
integer, parameter :: u1 = 1
integer, parameter :: u2 = 2
integer, parameter :: u3 = 3
integer, parameter :: u4 = 4
type(my_type) :: t1
type(my_type) :: t2
type(my_type) :: t3
type(my_type) :: t4
integer :: i
integer :: invariant
integer :: ic1
integer :: ic2
integer :: ic3
real, dimension(10) :: a
<BLANKLINE>
invariant = 1
do i = 1, 10, 1
t1%a = 13
a(ic1) = 1 + (ic1 + 3) * ic1
a(ic2) = 2 + (ic2 + 12) * ic2
a(ic3) = 3 + (ic3 + 13) * ic3
a(t1%a) = 4 + (t1%a + 4 * 13) * t1%a
enddo
<BLANKLINE>
end program test
<BLANKLINE>

'''

_ERROR_MSG = (
"Psyclone(ReplaceReferenceByLiteral): only "
+ "supports symbols which have a Literal as their initial value but "
)

def __init__(self) -> None:
super().__init__()
# Dictionary with Literal values of the corresponding symbol
# from symbol_table (based on symbol name as a string).
self._param_table: Dict[str, Literal] = {}

def _update_param_table(
self,
param_table: Dict[str, Literal],
symbol_table: SymbolTable,
) -> Dict[str, Literal]:
"""This methods takes a param_table as entry, updates this dictionary
and then returns the same dictionary updated.

* Goes through all datasymbols in the symbol_table.
* if symbol already in param_table or initial_value is not Literal:
* annotate code with warning.
* copy and detach the symbol.initial_value (Literal)
* update the param_table with this copy.
* Returns the updated param_table

:param param_table: To be updated
:param symbol_table: scope symbol table to look for the symbols.
:return: the updated param_table (same reference as the entry one)
:rtype: Dict[str, Literal]
"""
for sym in symbol_table.datasymbols:
sym: DataSymbol
if sym.is_constant:
sym_name = sym.name
print(sym_name)
print(type(sym.initial_value))
print(param_table.get(sym_name))
if param_table.get(sym_name):
message = (
"Psyclone(ReplaceReferenceByLiteralTrans):"
+ f" Symbol already found {sym_name}."
+ " A conflict is possible."
+ "To avoid replacing by wrong value, "
+ "symbol is removed from param_table."
)
sym.preceding_comment += message
param_table.pop(sym_name)
continue
if not isinstance(sym.initial_value, Literal):
message = (
ReplaceReferenceByLiteralTrans._ERROR_MSG
+ f"{sym_name} is assigned "
+ f"a {type(sym.initial_value)}"
)
sym.preceding_comment += message
print(sym.preceding_comment)
continue
new_literal: Literal = sym.initial_value.copy().detach()
param_table[sym_name] = new_literal
return param_table

def _replace_bounds(
self,
current_shape: List[Union[Literal, Reference]],
param_table: Dict[str, Literal],
) -> List[Union[Literal, Reference]]:
"""From the param_table and the current_shape of an array,
this method create a new_shape with the reference replaced by literal
when they are found in the param_table.

:param current_shape: shape before transformation
:param param_table: table of parameters Literal values.
:return: the new shape with replaced reference by literal.
"""
new_shape = []
for dim in current_shape:
if isinstance(dim, ArrayType.ArrayBounds):
dim_upper = dim.upper.copy()
dim_lower = dim.lower.copy()
ref: DataNode = dim.upper
if isinstance(ref, Reference) and ref.name in param_table:
literal: Literal = param_table[ref.name]
dim_upper = literal.copy()
ref = dim.lower
if isinstance(ref, Reference) and ref.name in param_table:
literal: Literal = param_table[ref.name]
dim_lower = literal.copy()
new_bounds = ArrayType.ArrayBounds(dim_lower, dim_upper)
new_shape.append(new_bounds)
else:
# This dimension is specified with an ArrayType.Extent
# so no need to copy.
new_shape.append(dim)
return new_shape

# ------------------------------------------------------------------------
def apply(self, node: Routine, options=None):
"""Applies the transformation to a Routine node:
* First update a dictionary (param_table) with the Literal of constant
(parameter) symbol from node.parent symbol_table, and from
node.symbol_table.
* Second, use this updated param_table to replace reference in node
psyir_tree with the corresponsing Literal.
* Third, use this updated param_table to replace reference in node
symbol_table DataSymbol array's dimensions with the corresponsing
Literal.

:param node: _description_
:param options: not used, defaults to None
:type options: _type_, optional
"""
## Reset the param table for the current Routine
self._param_table = {}
self.validate(node, options)
## NOTE: (From Andrew) We may want to look at all symbols in scope
# rather than just those in the parent symbol table?
if node.parent is not None and isinstance(node.parent, Container):
if node.parent.symbol_table is not None:
self._param_table = self._update_param_table(
self._param_table, node.parent.symbol_table
)
## NOTE: and other parent scopes?
# symbol_table.parent_symbol_table

## node.symbol_table is not None (in validate)
self._param_table = self._update_param_table(
self._param_table, node.symbol_table
)

for ref in node.walk(Reference):
ref: Reference
if ref.name in self._param_table:
literal: Literal = self._param_table[ref.name]
ref.replace_with(literal.copy())

for sym in node.symbol_table.datasymbols:
sym: DataSymbol
if sym.is_array:
if not isinstance(sym.datatype, UnsupportedFortranType):
new_shape: List[Union[Literal, Reference]] = (
self._replace_bounds(sym.shape, self._param_table)
)
sym.datatype = ArrayType(sym.datatype.datatype, new_shape)

# ------------------------------------------------------------------------
def validate(self, node, options=None):
"""Perform various checks to ensure that it is valid to apply the
ReplaceReferenceByLiteralTrans transformation to the supplied PSyIR
Node.

:param node: the node that is being checked.
:param options: not used, defaults to None
:type options: _type_, optional
:raises TransformationError: if the node argument is not a Routine.
"""
if not isinstance(node, Routine):
raise TransformationError(
f"Error in {self.name} transformation. The supplied node "
f"argument should be a PSyIR Routine, but found "
f"'{type(node).__name__}'."
)


__all__ = ["ReplaceReferenceByLiteralTrans"]
Loading
Loading