Skip to content

Commit

Permalink
backup the changes & stop to implement #4425[node-merging] faster & w…
Browse files Browse the repository at this point in the history
…ait for future refactoring when needed
  • Loading branch information
Stevengre committed Jun 26, 2024
1 parent 35cba26 commit 2bd91f6
Show file tree
Hide file tree
Showing 5 changed files with 138 additions and 18 deletions.
12 changes: 8 additions & 4 deletions pyk/src/pyk/kcfg/minimize.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
"""
Status: [LiftEdgeSplit, LiftSplitSplit, MergeSplitNodes] In Progress
"""

from .rewriter import KCFGRewriter, KCFGRewritePattern, NodeIdLike, KCFGRewriteWalker
from .kcfg import KCFG

Expand Down Expand Up @@ -36,16 +40,16 @@ def match_and_rewrite(self, node: NodeIdLike, rewriter: KCFGRewriter) -> bool:
return rewriter.commit(node, match_pattern, rewrite_pattern)


class LiftSplitSplit(KCFGRewritePattern):
class LiftEdgeSplit(KCFGRewritePattern):
def match_and_rewrite(self, node: NodeIdLike, rewriter: KCFGRewriter) -> bool:
match_pattern = ('S->N|split', 'N->T*|split',)
match_pattern = ('S->N|edge', 'N->T*|split',)
rewrite_pattern = ('S->S*|split', 'S*->T*|edge',)
return rewriter.commit(node, match_pattern, rewrite_pattern)


class LiftEdgeSplit(KCFGRewritePattern):
class LiftSplitSplit(KCFGRewritePattern):
def match_and_rewrite(self, node: NodeIdLike, rewriter: KCFGRewriter) -> bool:
match_pattern = ('S->N|edge', 'N->T*|split',)
match_pattern = ('S->N|split', 'N->T*|split',)
rewrite_pattern = ('S->S*|split', 'S*->T*|edge',)
return rewriter.commit(node, match_pattern, rewrite_pattern)

Expand Down
85 changes: 75 additions & 10 deletions pyk/src/pyk/kcfg/rewriter.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
"""
Status: [create(), merge_paths(), get_or_create_node] In Progress
[match()] may have some bugs for `N->T|split` pattern.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import final

from .kcfg import KCFG, NodeIdLike
from ..cterm import CTerm, CSubst
from ..cterm import CTerm
from ..utils import single


@final
@dataclass(frozen=True)
class Pats:
class Pat:
"""
The matching patterns for the KCFG, starting from a node N.
Expand Down Expand Up @@ -62,11 +67,13 @@ class KCFGRewriter:
"""The nodes of the matched KCFG."""
_edges: dict[str, tuple[KCFG.Successor, ...]]
"""The edges of the matched KCFG."""
_node_id: int

def __init__(self, kcfg: KCFG):
# Keep a reference to the KCFG this rewriter operates on.
self.kcfg = kcfg
# Matched nodes and edges, keyed by pattern string; both start empty.
self._nodes = {}
self._edges = {}
# Rewriter-local node-id counter, seeded from the KCFG's own counter.
# It is incremented whenever get_or_create_node builds a new KCFG.Node.
self._node_id = kcfg._node_id

def commit(
self,
Expand Down Expand Up @@ -120,7 +127,7 @@ def match(
if not patterns:
return self._get_matched()
# parse the patterns
patterns = [Pats(p) for p in patterns]
patterns = [Pat(p) for p in patterns]
# match the patterns
loop_count = -1
# the worst case is that only one pattern is matched in each loop
Expand Down Expand Up @@ -194,10 +201,7 @@ def _get_matched(
matched.add_successor(edge)
return matched

def create(
self,
rewrite_patterns: tuple[str, ...],
) -> KCFG:
def create(self, rewrite_patterns: tuple[str, ...],) -> KCFG:
"""
Create a new KCFG by rewriting the matched subgraph.
Expand All @@ -207,16 +211,18 @@ def create(
# initialization
nodes: dict[str, tuple[KCFG.Node, ...]] = {}
edges: dict[str, tuple[KCFG.Successor, ...]] = {}
patterns = [Pats(p) for p in rewrite_patterns]
patterns = tuple(Pat(p) for p in rewrite_patterns)
for pattern in patterns:
# todo: only the patterns needed for minimization are supported so far; add support for more patterns when needed.
source = self.get_or_create_node(pattern.source, patterns)
source = self._nodes.get(pattern.source)
target = self._nodes.get(pattern.target)
if not source or not target:
raise NotImplementedError(f"Pattern {pattern} is not supported.")
match pattern.edge_type:
match pattern.edge_type: # todo: transfer it into a function, like `create_edge` or `create_successor`
case 'edge':
if pattern.is_multi_source and pattern.is_multi_target:

raise NotImplementedError(f"Pattern {pattern} is not supported.")
elif not pattern.is_multi_source and not pattern.is_multi_target:
assert len(source) == 1 and len(target) == 1, \
Expand Down Expand Up @@ -265,6 +271,65 @@ def merge_paths(self, source: int, target: int) -> tuple[int, tuple[str, ...]]:
continue
return depth, tuple(rules)

def get_or_create_node(
    self,
    node_name: str,
    rewrite_patterns: tuple[Pat, ...]
) -> tuple[KCFG.Node, ...]:
    """
    Get the nodes bound to a pattern name in self._nodes, or create new nodes if there are none.

    Input: The node pattern name (e.g. `S` or `S*`) and the parsed rewrite patterns.
    Output: The nodes bound to the pattern name.

    Only the `S*` case is implemented: the new nodes are the matched node `S`
    strengthened with the constraints of each branch of the matched
    `...->T*|split` split. Creating a plain `S` from `S*` raises
    NotImplementedError (todo: needed for MergeSplitNodes).
    Note: newly created nodes are returned but not stored in self._nodes.
    """
    def _create_node(content: CTerm) -> KCFG.Node:
        # Allocate the next id from the rewriter-local counter.
        self._node_id += 1
        return KCFG.Node(self._node_id, content)

    node = self._nodes.get(node_name)
    if node:
        return tuple(node)

    if not node_name.endswith('*'):
        s_name = node_name + '*'
        s = self._nodes.get(s_name)
        assert s, f"Node {s_name} should exist."
        # is below necessary?
        assert Pat(f"{node_name}->{s_name}|split") in rewrite_patterns, \
            f"Pattern {node_name}->{s_name}|split should exist in rewrite patterns."
        raise NotImplementedError("Not implemented yet.")  # todo: when implementing MergeSplitNodes

    # node_name = S*: drop exactly the one-character '*' suffix to get S.
    s_name = node_name.removesuffix('*')
    s = single(self._nodes.get(s_name))
    assert s, f"Node {s_name} should exist."
    # is below necessary?
    assert Pat(f"{s_name}->{node_name}|split") in rewrite_patterns, \
        f"Pattern {s_name}->{node_name}|split should exist in rewrite patterns."

    # discover the constraints of the new node
    # Now, discover the constraints via 'S*->T*|edge' in rewrite pattern; and 'N->T*|split' in match pattern
    # todo: add more discovery methods when needed
    pattern = single(p for p in rewrite_patterns if p.source == node_name)
    assert pattern.target.endswith('*') and pattern.edge_type == 'edge', \
        f"Pattern {pattern} is not supported for discovering constraints for {node_name}"

    # The matched split is the one whose key ends with '->T*|split' in the match pattern.
    split_key_suffix = f"->{pattern.target}|split"
    matched_successors = single(
        successors for key, successors in self._edges.items() if key.endswith(split_key_suffix)
    )
    assert len(matched_successors) == 1, f"Match edge for {pattern} should be unique."
    match_edge = matched_successors[0]
    assert isinstance(match_edge, KCFG.Split), f"Match edge for {pattern} should be Split."

    # Ensure split can be lifted soundly (i.e., that it does not introduce fresh variables)
    assert (
        len(match_edge.source_vars.difference(s.free_vars)) == 0
        and len(match_edge.target_vars.difference(match_edge.source_vars)) == 0
    )

    # Create CTerms corresponding to the new targets of the split:
    # the configuration of S with the constraints of each split branch added.
    new_cterms = [
        CTerm(s.cterm.config, s.cterm.constraints + csubst.constraints)
        for csubst in match_edge.splits.values()
    ]
    return tuple(_create_node(cterm) for cterm in new_cterms)

# todo: move it to kcfg.py, maybe as an argument of remove_node
def remove_node_safe(self, node_id: int) -> bool:
"""
Expand Down
55 changes: 53 additions & 2 deletions pyk/src/tests/unit/kcfg/test_minimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from pyk.kast.inner import KVariable
from pyk.kcfg import KCFG
from pyk.kcfg.minimize import LiftEdgeEdge
from pyk.kcfg.minimize import LiftEdgeEdge, LiftEdgeSplit
from pyk.kcfg.rewriter import KCFGRewriteWalker, KCFGRewritePattern
from pyk.prelude.kint import geInt, intToken, ltInt
from pyk.prelude.ml import mlEqualsTrue
Expand Down Expand Up @@ -52,8 +52,52 @@ def kcfg_two_splits_lifted_edge_edge() -> KCFG:
return cfg


def kcfg_two_splits_lifted_edge_edge_lifted_edge_split() -> KCFG:
    """
    Build the KCFG used as the expected result of LiftEdgeSplit on the edge-lifted two-splits KCFG.

                                          50            25
                      /-- X >=Int 5 --> 1** (17) --> 6* (19) --> 10
      /-- X >=Int 0 --> 1* (15)           50            25
     1                \\-- X <Int 5 --> 1** (18) --> 6* (20) --> 11
      \\                          50     105
       \\-- X <Int 0 --> 1* (16) --> 7 --> 13

    TODO (original author's note): this structure is not implemented for now.
    """

    def _x_constraint(compare, bound):
        # Boolean constraint `X <compare> bound`, wrapped as an ML predicate.
        return mlEqualsTrue(compare(KVariable('X'), intToken(bound)))

    x_ge_0 = _x_constraint(geInt, 0)
    x_lt_0 = _x_constraint(ltInt, 0)
    x_ge_5 = _x_constraint(geInt, 5)
    x_lt_5 = _x_constraint(ltInt, 5)

    # Rules carried by the depth-50 edges.
    lifted_rules = ('r1', 'r2', 'r3', 'r4')

    raw = {
        'next': 21,
        'nodes': node_dicts(20, config=x_config()),
        'edges': edge_dicts(
            (17, 19, 50, lifted_rules),
            (18, 20, 50, lifted_rules),
            (19, 10, 25, ('r5',)),
            (20, 11, 25, ('r5',)),
            (16, 7, 50, lifted_rules),
            (7, 13, 105, ('r6', 'r7', 'r8')),
        ),
        'splits': split_dicts(
            (1, [(15, x_ge_0), (16, x_lt_0)]),
            (15, [(17, x_ge_5), (18, x_lt_5)]),
            csubst=x_subst(),
        ),
    }
    # Drop the node ids that do not appear in this KCFG before building it.
    result = KCFG.from_dict(pop_nodes_by_ids(raw, {2, 3, 4, 5, 8, 9, 12}))
    propagate_split_constraints(result)
    return result


# todo: lifted edge-split
# 5 10 15 20 25
# /-- X >=Int 5 --> 1** --> 2** --> 3** --> 4** --> 6* --> 10
# /-- X >=Int 0 --> 1* 5 10 15 20 25
# 1 \-- X <Int 5 --> 1** --> 2** --> 3** --> 4** --> 6* --> 11
# \ 5 10 15 20 25 30 35 40
# \-- X <Int 0 --> 1* --> 2* --> 3* --> 4* --> 7 --> 9 --> 12 --> 13


TWO_SPLITS = minimization_test_kcfg()
TWO_SPLITS_LIFTED_EDGE_EDGE = kcfg_two_splits_lifted_edge_edge()
TWO_SPLITS_LEE_LES = kcfg_two_splits_lifted_edge_edge_lifted_edge_split()


# ------------------------------
Expand All @@ -65,11 +109,18 @@ def assert_no_edge_edge(kcfg: KCFG):
assert not (kcfg.edges(source_id=node.id) and kcfg.edges(target_id=node.id))


def assert_no_edge_split(kcfg: KCFG):
    """
    Assert that no edge is directly followed by a split, i.e. there is no A->B->C*.

    A node B violates this when it is the target of an edge (A->B) and the
    source of a split (B->C*). The opposite shape (split followed by edge) is
    what lifting produces and is allowed.
    Raises AssertionError on the first offending node.
    """
    for node in kcfg.nodes:
        # `node` plays the role of B: incoming edge AND outgoing split.
        assert not (kcfg.edges(target_id=node.id) and kcfg.splits(source_id=node.id)), \
            f"Edge followed by split at node {node.id}"


# ------------------------------
# Test Cases
# ------------------------------
@pytest.mark.parametrize("patterns, input_kcfg, assertions, expected_kcfg", [
([LiftEdgeEdge()], TWO_SPLITS, [assert_no_edge_edge], TWO_SPLITS_LIFTED_EDGE_EDGE),
# ([LiftEdgeEdge()], TWO_SPLITS, [assert_no_edge_edge], TWO_SPLITS_LIFTED_EDGE_EDGE),
([LiftEdgeSplit()], TWO_SPLITS_LIFTED_EDGE_EDGE, [assert_no_edge_split], TWO_SPLITS_LEE_LES),
])
def test_minimize_patterns(
patterns: list[KCFGRewritePattern],
Expand Down

0 comments on commit 2bd91f6

Please sign in to comment.