-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add ability to convert from SAT result to LTL formula
- Loading branch information
Showing
6 changed files
with
124 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,69 @@ | ||
import logging | ||
|
||
from z3 import And, Solver, is_true | ||
|
||
from ltl_learner.constants import operators | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class Node: | ||
def __init__(self, id, label, left = None, right = None): | ||
self.label = label | ||
self.left = left | ||
self.right = right | ||
self.id = id | ||
|
||
def __str__(self): | ||
acc = f'{self.label}' | ||
if self.label in operators['all']: | ||
acc += '(' | ||
if self.left: | ||
acc += f'{self.left}' | ||
if self.right: | ||
acc += f',{self.right}' | ||
acc += ')' | ||
return acc | ||
|
||
class Tree: | ||
def __init__(self): | ||
self.root = None | ||
|
||
def add_node(self, node): | ||
self.nodes.append(node) | ||
|
||
def get_formula(self): | ||
return str(self.root) | ||
|
||
|
||
class LTLConverter: | ||
def __init__(self, solver: Solver): | ||
self.solver = solver | ||
|
||
def build(self): | ||
psi = self.solver.model() | ||
true_vars = {x.name(): x for x in psi.decls() if is_true(psi[x])} | ||
dag = [x for x in true_vars.keys() if x.startswith('x_') or x.startswith('l_') or x.startswith('r_')] | ||
ys = [y for y in true_vars.keys() if y.startswith('y_')] | ||
print(dag) | ||
print(ys) | ||
def build(self, length: int, true_nodes = None): | ||
if not true_nodes: | ||
psi = self.solver.model() | ||
true_vars = {x.name(): x for x in psi.decls() if is_true(psi[x])} | ||
dag = [x for x in true_vars.keys() if x.startswith('x_') or x.startswith('l_') or x.startswith('r_')] | ||
logger.info(f"Variables set to true: {dag}") | ||
true_nodes = list(sorted(dag, key = lambda n: n.split('_')[1])) | ||
tree = Tree() | ||
nodes = {} | ||
for i in range(length): | ||
label = [n for n in true_nodes if n.startswith(f'x_{i}_')][0].split('_')[-1] | ||
nodes[i] = Node(i, label) | ||
for i in range(length): | ||
left = [n for n in true_nodes if n.startswith(f'l_{i}_')] | ||
if left: | ||
left = int(left[0].split('_')[-1]) | ||
nodes[i].left = nodes[left] | ||
right = [n for n in true_nodes if n.startswith(f'r_{i}_')] | ||
if right: | ||
right = int(right[0].split('_')[-1]) | ||
nodes[i].right = nodes[right] | ||
tree.root = nodes[length - 1] | ||
logger.info('Computed tree from SAT assignation.') | ||
logger.info(f' {tree}') | ||
logger.info('LTL Formula:') | ||
logger.info(f' {tree.get_formula()}') | ||
return tree.get_formula() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import pytest | ||
from z3 import Solver | ||
|
||
from ltl_learner.ltl.converter import LTLConverter | ||
|
||
|
||
@pytest.fixture | ||
def converter(): | ||
return LTLConverter(Solver()) | ||
|
||
|
||
@pytest.fixture | ||
def result_length_7(): | ||
return [ | ||
'x_6_U', | ||
'l_6_4', | ||
'x_4_!', | ||
'l_4_3', | ||
'x_3_F', | ||
'l_3_2', | ||
'x_2_&', | ||
'l_2_1', | ||
'x_1_crit2', | ||
'r_2_0', | ||
'x_0_crit1', | ||
'r_6_5', | ||
'x_5_|', | ||
'l_5_1', | ||
'r_5_0', | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from tests.fixtures.results import result_length_7, converter | ||
|
||
def test_tree_str(result_length_7, converter): | ||
tree = converter.build(length = 7, true_nodes = result_length_7) | ||
assert tree == 'U(!(F(&(crit2,crit1))),|(crit2,crit1))' |