diff --git a/pyk/src/pyk/k2lean4/Prelude.lean b/pyk/src/pyk/k2lean4/Prelude.lean index d8b5de0e97..a7d4e04c63 100644 --- a/pyk/src/pyk/k2lean4/Prelude.lean +++ b/pyk/src/pyk/k2lean4/Prelude.lean @@ -20,7 +20,7 @@ These theorems should be provable directly from the function rules and the seman -/ -- Basic K types -abbrev SortBool : Type := Int +abbrev SortBool : Type := Bool abbrev SortBytes : Type := ByteArray abbrev SortId : Type := String abbrev SortInt : Type := Int diff --git a/pyk/src/pyk/k2lean4/k2lean4.py b/pyk/src/pyk/k2lean4/k2lean4.py index b05f0a02f9..5a2247f366 100644 --- a/pyk/src/pyk/k2lean4/k2lean4.py +++ b/pyk/src/pyk/k2lean4/k2lean4.py @@ -2,14 +2,18 @@ import re from dataclasses import dataclass +from functools import cached_property from graphlib import TopologicalSorter from itertools import count -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, NamedTuple +from ..dequote import bytes_encode from ..konvert import unmunge from ..kore.internal import CollectionKind -from ..kore.syntax import SortApp -from ..utils import POSet +from ..kore.kompiled import KoreSymbolTable +from ..kore.manip import elim_aliases, free_occs +from ..kore.syntax import DV, And, App, EVar, SortApp, String, Top +from ..utils import FrozenDict, POSet from .model import ( Alt, AltsFieldVal, @@ -31,10 +35,12 @@ ) if TYPE_CHECKING: + from collections.abc import Iterable, Iterator, Mapping from typing import Final from ..kore.internal import KoreDefn - from ..kore.syntax import SymbolDecl + from ..kore.rule import RewriteRule + from ..kore.syntax import Pattern, Sort, SymbolDecl from .model import Binder, Command, Declaration, FieldVal @@ -45,10 +51,99 @@ _PRELUDE_SORTS: Final = {'SortBool', 'SortBytes', 'SortId', 'SortInt', 'SortString', 'SortStringBuffer'} +class Field(NamedTuple): + name: str + ty: Term + + @dataclass(frozen=True) class K2Lean4: defn: KoreDefn + @cached_property + def symbol_table(self) -> KoreSymbolTable: + return KoreSymbolTable(self.defn.symbols.values()) + + @cached_property + def structure_symbols(self) -> FrozenDict[str, str]: + def constructed_by(symbol: str) -> str | None: + decl = self.defn.symbols[symbol] + _sort = decl.sort + + if not isinstance(_sort, SortApp): + return None + + sort = _sort.name + + if not self._is_cell(sort) and not self._is_collection(sort): + return None + + if symbol not in self.defn.constructors.get(sort, ()): + return None + + return sort + + return FrozenDict( + (symbol, sort) for symbol in self.defn.symbols if (sort := constructed_by(symbol)) is not None + ) + + @cached_property + def structures(self) -> FrozenDict[str, tuple[Field, ...]]: + def fields_of(sort: str) -> tuple[Field, ...] | None: + if self._is_cell(sort): + return self._cell_fields(sort) + + if self._is_collection(sort): + return (self._collection_field(sort),) + + return None + + return FrozenDict((sort, fields) for sort in self.defn.sorts if (fields := fields_of(sort)) is not None) + + @staticmethod + def _is_cell(sort: str) -> bool: + return sort.endswith('Cell') + + def _cell_fields(self, sort: str) -> tuple[Field, ...]: + (ctor,) = self.defn.constructors[sort] + decl = self.defn.symbols[ctor] + sorts = _param_sorts(decl) + + names: list[str] + if all(self._is_cell(sort) for sort in sorts): + names = [] + for sort in sorts: + assert sort.startswith('Sort') + assert sort.endswith('Cell') + name = sort[4:-4] + name = name[0].lower() + name[1:] + names.append(name) + else: + assert len(sorts) == 1 + names = ['val'] + + return tuple(Field(name, Term(sort)) for name, sort in zip(names, sorts, strict=True)) + + def _is_collection(self, sort: str) -> bool: + return sort in self.defn.collections + + def _collection_field(self, sort: str) -> Field: + coll = self.defn.collections[sort] + elem = self.defn.symbols[coll.element] + sorts = _param_sorts(elem) + term: Term + match coll.kind: + case CollectionKind.LIST: + (item,) = sorts + term = Term(f'(ListHook {item}).list') + case CollectionKind.SET: + (item,) = sorts + term = Term(f'(SetHook {item}).set') + case CollectionKind.MAP: + key, value = sorts + term = Term(f'(MapHook {key} {value}).map') + return Field('coll', term) + def sort_module(self) -> Module: commands: tuple[Command, ...] = tuple( block for sorts in _ordered_sorts(self.defn) if (block := self._sort_block(sorts)) is not None @@ -75,17 +170,11 @@ def is_inductive(sort: str) -> bool: decl = self.defn.sorts[sort] return not decl.hooked and 'hasDomainValues' not in decl.attrs_by_key and not self._is_cell(sort) - def is_collection(sort: str) -> bool: - return sort in self.defn.collections - if is_inductive(sort): return self._inductive(sort) - if self._is_cell(sort): - return self._cell(sort) - - if is_collection(sort): - return self._collection(sort) + if sort in self.structures: + return self._structure(sort) raise AssertionError @@ -116,49 +205,10 @@ def _symbol_ident(symbol: str) -> str: symbol = f'«{symbol}»' return symbol - @staticmethod - def _is_cell(sort: str) -> bool: - return sort.endswith('Cell') - - def _cell(self, sort: str) -> Structure: - (cell_ctor,) = self.defn.constructors[sort] - decl = self.defn.symbols[cell_ctor] - param_sorts = _param_sorts(decl) - - param_names: list[str] - - if all(self._is_cell(sort) for sort in param_sorts): - param_names = [] - for param_sort in param_sorts: - assert param_sort.startswith('Sort') - assert param_sort.endswith('Cell') - name = param_sort[4:-4] - name = name[0].lower() + name[1:] - param_names.append(name) - else: - assert len(param_sorts) == 1 - param_names = ['val'] - - fields = tuple(ExplBinder((name,), Term(sort)) for name, sort in zip(param_names, param_sorts, strict=True)) - return Structure(sort, Signature((), Term('Type')), ctor=StructCtor(fields)) - - def _collection(self, sort: str) -> Structure: - coll = self.defn.collections[sort] - elem = self.defn.symbols[coll.element] - sorts = _param_sorts(elem) - val: Term - match coll.kind: - case CollectionKind.LIST: - (item,) = sorts - val = Term(f'(ListHook {item}).list') - case CollectionKind.SET: - (item,) = sorts - val = Term(f'(SetHook {item}).set') - case CollectionKind.MAP: - key, value = sorts - val = Term(f'(MapHook {key} {value}).map') - field = ExplBinder(('coll',), val) - return Structure(sort, Signature((), Term('Type')), ctor=StructCtor((field,))) + def _structure(self, sort: str) -> Structure: + fields = self.structures[sort] + binders = tuple(ExplBinder((name,), ty) for name, ty in fields) + return Structure(sort, Signature((), Term('Type')), ctor=StructCtor(binders)) def inj_module(self) -> Module: return Module(commands=self._inj_commands()) @@ -226,6 +276,224 @@ def _transform_func(self, func: str) -> Axiom: binders.extend(ExplBinder((f'x{i}',), Term(sort)) for i, sort in enumerate(param_sorts)) return Axiom(ident, Signature(binders, Term(f'Option {sort}'))) + def rewrite_module(self) -> Module: + commands = (self._rewrite_inductive(),) + return Module(commands=commands) + + def _rewrite_inductive(self) -> Inductive: + def tran_ctor() -> Ctor: + return Ctor( + 'tran', + Signature( + ( + ImplBinder(('s1', 's2', 's3'), Term('SortGeneratedTopCell')), + ExplBinder(('t1',), Term('Rewrites s1 s2')), + ExplBinder(('t2',), Term('Rewrites s2 s3')), + ), + Term('Rewrites s1 s3'), + ), + ) + + ctors: list[Ctor] = [] + ctors.append(tran_ctor()) + ctors.extend(self._rewrite_ctors()) + signature = Signature(ty=Term('SortGeneratedTopCell → SortGeneratedTopCell → Prop')) + return Inductive('Rewrites', signature, ctors=ctors) + + def _rewrite_ctors(self) -> list[Ctor]: + rewrites = sorted(self.defn.rewrites, key=self._rewrite_name) + return [self._rewrite_ctor(rule) for rule in rewrites] + + def _rewrite_ctor(self, rule: RewriteRule) -> Ctor: + req = rule.req if rule.req else Top(SortApp('Foo')) + + # Step 1: eliminate aliases + pattern = elim_aliases(And(SortApp('Foo'), (req, rule.lhs, rule.rhs))) + + # Step 2: eliminate function application + free = (f"Var'Unds'Val{i}" for i in count()) + pattern, defs = self._elim_fun_apps(pattern, free) + + # Step 3: create binders + binders: list[Binder] = [] + binders.extend(self._free_binders(pattern)) # Binders of the form {x y : SortInt} + binders.extend(self._def_binders(defs)) # Binders of the form (def_y : foo x = some y) + + # Step 4: transform patterns + assert isinstance(pattern, And) + req, lhs, rhs = pattern.ops + + if not isinstance(req, Top): + req_term = self._transform_pattern(req) + binders.append(ExplBinder(('req',), Term(f'{req_term} = true'))) + + lhs_term = self._transform_pattern(lhs) + rhs_term = self._transform_pattern(rhs) + return Ctor(self._rewrite_name(rule), Signature(binders, Term(f'Rewrites {lhs_term} {rhs_term}'))) + + @staticmethod + def _rewrite_name(rule: RewriteRule) -> str: + if rule.label: + return rule.label.replace('-', '_').replace('.', '_') + return f'_{rule.uid[:7]}' + + @staticmethod + def _var_ident(name: str) -> str: + assert name.startswith('Var') + return K2Lean4._symbol_ident(name[3:]) + + def _elim_fun_apps(self, pattern: Pattern, free: Iterator[str]) -> tuple[Pattern, dict[str, Pattern]]: + """Replace ``foo(bar(x))`` with ``z`` and return mapping ``{y: bar(x), z: foo(y)}`` with ``y``, ``z`` fresh variables.""" + defs: dict[str, Pattern] = {} + + def abstract_funcs(pattern: Pattern) -> Pattern: + if isinstance(pattern, App) and pattern.symbol in self.defn.functions: + name = next(free) + ident = self._var_ident(name) + defs[ident] = pattern + sort = self.symbol_table.infer_sort(pattern) + return EVar(name, sort) + return pattern + + return pattern.bottom_up(abstract_funcs), defs + + def _free_binders(self, pattern: Pattern) -> list[Binder]: + free_vars = {occ for _, occs in free_occs(pattern).items() for occ in occs} + grouped_vars: dict[str, set[str]] = {} + for var in free_vars: + match var: + case EVar(name, SortApp(sort)): + ident = self._var_ident(name) + assert ident not in grouped_vars.get(sort, ()) + grouped_vars.setdefault(sort, set()).add(ident) + case _: + raise AssertionError() + sorted_vars: dict[str, list[str]] = dict( + sorted(((sort, sorted(idents)) for sort, idents in grouped_vars.items()), key=lambda item: item[1][0]) + ) + return [ImplBinder(idents, Term(sort)) for sort, idents in sorted_vars.items()] + + def _def_binders(self, defs: Mapping[str, Pattern]) -> list[Binder]: + return [ + ExplBinder((f'defn{ident}',), Term(f'{self._transform_pattern(pattern)} = some {ident}')) + for ident, pattern in defs.items() + ] + + def _transform_pattern(self, pattern: Pattern) -> Term: + match pattern: + case EVar(name): + return self._transform_evar(name) + case DV(SortApp(sort), String(value)): + return self._transform_dv(sort, value) + case App(symbol, sorts, args): + return self._transform_app(symbol, sorts, args) + case _: + raise ValueError(f'Unsupported pattern: {pattern.text}') + + def _transform_evar(self, name: str) -> Term: + return Term(self._var_ident(name)) + + def _transform_dv(self, sort: str, value: str) -> Term: + match sort: + case 'SortBool': + return Term(value) + case 'SortInt': + return self._transform_int_dv(value) + case 'SortBytes': + return self._transform_bytes_dv(value) + case 'SortId' | 'SortString' | 'SortStringBuffer': + return self._transform_string_dv(value) + case _: + raise ValueError(f'Unsupported sort: {sort}') + + def _transform_int_dv(self, value: str) -> Term: + val = int(value) + return Term(str(val)) if val >= 0 else Term(f'({val})') + + def _transform_bytes_dv(self, value: str) -> Term: + bytes_str = ', '.join(f'0x{byte:02X}' for byte in bytes_encode(value)) + return Term(f'⟨#[{bytes_str}⟩]') + + def _transform_string_dv(self, value: str) -> Term: + escapes = { + ord('\r'): r'\r', + ord('\n'): r'\n', + ord('\t'): r'\t', + ord('\\'): r'\\', + ord('"'): r'\"', + ord("'"): r"\'", + } + + def encode(c: str) -> str: + code = ord(c) + if code in escapes: + return escapes[code] + elif 32 <= code < 127: + return c + elif code <= 0xFF: + return fr'\x{code:02x}' + elif code <= 0xFFFF: + return fr'\u{code:04x}' + else: + raise ValueError(f"Unsupported character: '{c}' ({code})") + + encoded = ''.join(encode(c) for c in value) + return Term(f'"{encoded}"') + + def _transform_app(self, symbol: str, sorts: tuple[Sort, ...], args: tuple[Pattern, ...]) -> Term: + if symbol == 'inj': + return self._transform_inj_app(sorts, args) + + if symbol in self.structure_symbols: + fields = self.structures[self.structure_symbols[symbol]] + return self._transform_structure_app(fields, args) + + decl = self.defn.symbols[symbol] + sort = decl.sort.name if isinstance(decl.sort, SortApp) else None + return self._transform_basic_app(sort, symbol, args) + + def _transform_arg(self, pattern: Pattern) -> Term: + term = self._transform_pattern(pattern) + + if not isinstance(pattern, App): + return term + + if pattern.symbol in self.structure_symbols: + return term + + return Term(f'({term})') + + def _transform_inj_app(self, sorts: tuple[Sort, ...], args: tuple[Pattern, ...]) -> Term: + _from_sort, _to_sort = sorts + assert isinstance(_from_sort, SortApp) + assert isinstance(_to_sort, SortApp) + from_str = _from_sort.name + to_str = _to_sort.name + (arg,) = args + term = self._transform_arg(arg) + return Term(f'(@inj {from_str} {to_str}) {term}') + + def _transform_structure_app(self, fields: Iterable[Field], args: Iterable[Pattern]) -> Term: + fields_str = ', '.join( + f'{field.name} := {self._transform_pattern(arg)}' for field, arg in zip(fields, args, strict=True) + ) + lbrace, rbrace = ['{', '}'] + return Term(f'{lbrace} {fields_str} {rbrace}') + + def _transform_basic_app(self, sort: str | None, symbol: str, args: Iterable[Pattern]) -> Term: + chunks = [] + + ident: str + if sort and symbol in self.defn.constructors.get(sort, ()): + # Symbol is a constructor + ident = f'{sort}.{self._symbol_ident(symbol)}' + else: + ident = self._symbol_ident(symbol) + + chunks.append(ident) + chunks.extend(str(self._transform_arg(arg)) for arg in args) + return Term(' '.join(chunks)) + def _param_sorts(decl: SymbolDecl) -> list[str]: from ..utils import check_type diff --git a/pyk/src/pyk/kore/internal.py b/pyk/src/pyk/kore/internal.py index 588bcf35b3..cfafb1e75a 100644 --- a/pyk/src/pyk/kore/internal.py +++ b/pyk/src/pyk/kore/internal.py @@ -86,6 +86,8 @@ def from_definition(defn: Definition) -> KoreDefn: sorts[name] = sent case SymbolDecl(Symbol(name)): symbols[name] = sent + if 'function' in sent.attrs_by_key: + functions.setdefault(name, []) case Axiom(attrs=(App('subsort', (SortApp(subsort), SortApp(supersort))),)): subsorts.append((subsort, supersort)) case Axiom(): diff --git a/pyk/src/pyk/kore/manip.py b/pyk/src/pyk/kore/manip.py index affe002ab5..effcd2bf40 100644 --- a/pyk/src/pyk/kore/manip.py +++ b/pyk/src/pyk/kore/manip.py @@ -55,3 +55,30 @@ def add_symbol(pattern: Pattern) -> None: pattern.collect(add_symbol) return res + + +def elim_aliases(pattern: Pattern) -> Pattern: + r"""Eliminate subpatterns of the form ``\and{S}(p, X : S)``. + + Both the ``\and`` and instances of ``X : S`` are replaced by the definition ``p``. + """ + aliases = {} + + def inline_aliases(pattern: Pattern) -> Pattern: + match pattern: + case And(_, (p, EVar(name))): + aliases[name] = p + return p + case _: + return pattern + + def substitute_vars(pattern: Pattern) -> Pattern: + match pattern: + case EVar(name) as var: + return aliases.get(name, var) + case _: + return pattern + + pattern = pattern.bottom_up(inline_aliases) + pattern = pattern.bottom_up(substitute_vars) + return pattern