Skip to content

Commit

Permalink
Fix issue that prevented handling rule application for rules of the f…
Browse files Browse the repository at this point in the history
…orm pat->Condition[expr_,cond]
  • Loading branch information
mmatera committed Nov 16, 2022
1 parent 2ffdf02 commit 62cf433
Show file tree
Hide file tree
Showing 7 changed files with 228 additions and 23 deletions.
3 changes: 2 additions & 1 deletion CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ Bugs

# ``0`` with a given precision (like in ```0`3```) is now parsed as ``0``, an integer number.
#. ``RandomSample`` with one list argument now returns a random ordering of the list items. Previously it would return just one item.

#. Rules of the form ``pat->Condition[expr, cond]`` are handled as in WL. The same also works for nested `Condition` expressions. In particular, the comparison between two Rules with the same pattern but an iterated ``Condition`` expressionare considered equal if the conditions are the same.


Enhancements
++++++++++++
Expand Down
14 changes: 12 additions & 2 deletions mathics/builtin/assignments/assignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,11 +170,21 @@ class SetDelayed(Set):
'Condition' ('/;') can be used with 'SetDelayed' to make an
assignment that only holds if a condition is satisfied:
>> f[x_] := p[x] /; x>0
>> f[x_] := p[-x]/; x<-2
>> f[3]
= p[3]
>> f[-3]
= f[-3]
It also works if the condition is set in the LHS:
= p[3]
>> f[-1]
= f[-1]
Notice that the LHS is the same in both definitions, but the second
does not overwrite the first one.
To overwrite one of these definitions, we have to assign using the same condition:
>> f[x_] := Sin[x] /; x>0
>> f[3]
= Sin[3]
In a similar way, the condition can be set in the LHS:
>> F[x_, y_] /; x < y /; x>0 := x / y;
>> F[x_, y_] := y / x;
>> F[2, 3]
Expand Down
13 changes: 10 additions & 3 deletions mathics/builtin/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,8 @@ def create_rules(rules_expr, expr, name, evaluation, extra_args=[]):
else:
result = []
for rule in rules:
if rule.get_head_name() not in ("System`Rule", "System`RuleDelayed"):
head_name = rule.get_head_name()
if head_name not in ("System`Rule", "System`RuleDelayed"):
evaluation.message(name, "reps", rule)
return None, True
elif len(rule.elements) != 2:
Expand All @@ -186,7 +187,13 @@ def create_rules(rules_expr, expr, name, evaluation, extra_args=[]):
)
return None, True
else:
result.append(Rule(rule.elements[0], rule.elements[1]))
result.append(
Rule(
rule.elements[0],
rule.elements[1],
delayed=(head_name == "System`RuleDelayed"),
)
)
return result, False


Expand Down Expand Up @@ -1690,7 +1697,7 @@ def __init__(self, rulelist, evaluation):
self._elements = None
self._head = SymbolDispatch

def get_sort_key(self) -> tuple:
def get_sort_key(self, pattern_sort=False) -> tuple:
return self.src.get_sort_key()

def get_atom_name(self):
Expand Down
45 changes: 33 additions & 12 deletions mathics/core/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@

from typing import List, Optional

from mathics.core.atoms import String
from mathics.core.atoms import Integer, String
from mathics.core.attributes import A_NO_ATTRIBUTES
from mathics.core.convert.expression import to_mathics_list
from mathics.core.element import fully_qualified_symbol_name
from mathics.core.expression import Expression
from mathics.core.pattern import Pattern
from mathics.core.rules import Rule
from mathics.core.symbols import (
Atom,
Symbol,
Expand Down Expand Up @@ -721,9 +723,6 @@ def get_ownvalue(self, name):
return None

def set_ownvalue(self, name, value) -> None:
from .expression import Symbol
from .rules import Rule

name = self.lookup_name(name)
self.add_rule(name, Rule(Symbol(name), value))
self.clear_cache(name)
Expand Down Expand Up @@ -759,8 +758,6 @@ def get_config_value(self, name, default=None):
return default

def set_config_value(self, name, new_value) -> None:
from mathics.core.expression import Integer

self.set_ownvalue(name, Integer(new_value))

def set_line_no(self, line_no) -> None:
Expand All @@ -780,6 +777,25 @@ def get_history_length(self):


def get_tag_position(pattern, name) -> Optional[str]:
# Strip first the pattern from HoldPattern, Pattern
# and Condition wrappings
while True:
# TODO: Not Atom/Expression,
# pattern -> pattern.to_expression()
if isinstance(pattern, Pattern):
pattern = pattern.expr
continue
if pattern.has_form("System`HoldPattern", 1):
pattern = pattern.elements[0]
continue
if pattern.has_form("System`Pattern", 2):
pattern = pattern.elements[1]
continue
if pattern.has_form("System`Condition", 2):
pattern = pattern.elements[0]
continue
break

if pattern.get_name() == name:
return "own"
elif isinstance(pattern, Atom):
Expand All @@ -788,10 +804,8 @@ def get_tag_position(pattern, name) -> Optional[str]:
head_name = pattern.get_head_name()
if head_name == name:
return "down"
elif head_name == "System`N" and len(pattern.elements) == 2:
elif pattern.has_form("System`N", 2):
return "n"
elif head_name == "System`Condition" and len(pattern.elements) > 0:
return get_tag_position(pattern.elements[0], name)
elif pattern.get_lookup_name() == name:
return "sub"
else:
Expand All @@ -801,11 +815,18 @@ def get_tag_position(pattern, name) -> Optional[str]:
return None


def insert_rule(values, rule) -> None:
def insert_rule(values: list, rule: Rule) -> None:
rhs_conds = getattr(rule, "rhs_conditions", [])
for index, existing in enumerate(values):
if existing.pattern.sameQ(rule.pattern):
del values[index]
break
# Check for coincidences in the replace conditions,
# it they are there.
# This ensures that the rules are equivalent even taking
# into accound the RHS conditions.
existing_rhs_conds = getattr(existing, "rhs_conditions", [])
if existing_rhs_conds == rhs_conds:
del values[index]
break
# use insort_left to guarantee that if equal rules exist, newer rules will
# get higher precedence by being inserted before them. see DownValues[].
bisect.insort_left(values, rule)
Expand Down
143 changes: 138 additions & 5 deletions mathics/core/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from mathics.core.element import KeyComparable
from mathics.core.expression import Expression
from mathics.core.symbols import strip_context
from mathics.core.symbols import strip_context, SymbolTrue
from mathics.core.pattern import Pattern, StopGenerator

from itertools import chain
Expand All @@ -19,6 +19,10 @@ def function_arguments(f):
return _python_function_arguments(f)


class StopMatchConditionFailed(StopGenerator):
pass


class StopGenerator_BaseRule(StopGenerator):
pass

Expand Down Expand Up @@ -59,7 +63,11 @@ def yield_match(vars, rest):
if name.startswith("_option_"):
options[name[len("_option_") :]] = value
del vars[name]
new_expression = self.do_replace(expression, vars, options, evaluation)
try:
new_expression = self.do_replace(expression, vars, options, evaluation)
except StopMatchConditionFailed:
return

if new_expression is None:
new_expression = expression
if rest[0] or rest[1]:
Expand Down Expand Up @@ -107,7 +115,7 @@ def yield_match(vars, rest):
def do_replace(self):
raise NotImplementedError

def get_sort_key(self) -> tuple:
def get_sort_key(self, pattern_sort=False) -> tuple:
# FIXME: check if this makes sense:
return tuple((self.system, self.pattern.get_sort_key(True)))

Expand All @@ -131,12 +139,131 @@ class Rule(BaseRule):
``G[1.^2, a^2]``
"""

def __init__(self, pattern, replace, system=False) -> None:
def __ge__(self, other):
if isinstance(other, Rule):
sys, key, rhs_cond = self.get_sort_key()
sys_other, key_other, rhs_cond_other = other.get_sort_key()
if sys != sys_other:
return sys > sys_other
if key != key_other:
return key > key_other

# larger and more complex conditions come first
len_cond, len_cond_other = len(rhs_cond), len(rhs_cond_other)
if len_cond != len_cond_other:
return len_cond_other > len_cond
if len_cond == 0:
return False
for me_cond, other_cond in zip(rhs_cond, rhs_cond_other):
me_sk = me_cond.get_sort_key(True)
o_sk = other_cond.get_sort_key(True)
if me_sk > o_sk:
return False
return True
# Follow the usual rule
return self.get_sort_key(True) >= other.get_sort_key(True)

def __gt__(self, other):
if isinstance(other, Rule):
sys, key, rhs_cond = self.get_sort_key()
sys_other, key_other, rhs_cond_other = other.get_sort_key()
if sys != sys_other:
return sys > sys_other
if key != key_other:
return key > key_other

# larger and more complex conditions come first
len_cond, len_cond_other = len(rhs_cond), len(rhs_cond_other)
if len_cond != len_cond_other:
return len_cond_other > len_cond
if len_cond == 0:
return False

for me_cond, other_cond in zip(rhs_cond, rhs_cond_other):
me_sk = me_cond.get_sort_key(True)
o_sk = other_cond.get_sort_key(True)
if me_sk > o_sk:
return False
return me_sk > o_sk
# Follow the usual rule
return self.get_sort_key(True) > other.get_sort_key(True)

def __le__(self, other):
if isinstance(other, Rule):
sys, key, rhs_cond = self.get_sort_key()
sys_other, key_other, rhs_cond_other = other.get_sort_key()
if sys != sys_other:
return sys < sys_other
if key != key_other:
return key < key_other

# larger and more complex conditions come first
len_cond, len_cond_other = len(rhs_cond), len(rhs_cond_other)
if len_cond != len_cond_other:
return len_cond_other < len_cond
if len_cond == 0:
return False
for me_cond, other_cond in zip(rhs_cond, rhs_cond_other):
me_sk = me_cond.get_sort_key(True)
o_sk = other_cond.get_sort_key(True)
if me_sk < o_sk:
return False
return True
# Follow the usual rule
return self.get_sort_key(True) <= other.get_sort_key(True)

def __lt__(self, other):
if isinstance(other, Rule):
sys, key, rhs_cond = self.get_sort_key()
sys_other, key_other, rhs_cond_other = other.get_sort_key()
if sys != sys_other:
return sys < sys_other
if key != key_other:
return key < key_other

# larger and more complex conditions come first
len_cond, len_cond_other = len(rhs_cond), len(rhs_cond_other)
if len_cond != len_cond_other:
return len_cond_other < len_cond
if len_cond == 0:
return False

for me_cond, other_cond in zip(rhs_cond, rhs_cond_other):
me_sk = me_cond.get_sort_key(True)
o_sk = other_cond.get_sort_key(True)
if me_sk < o_sk:
return False
return me_sk > o_sk
# Follow the usual rule
return self.get_sort_key(True) < other.get_sort_key(True)

def __init__(self, pattern, replace, delayed=True, system=False) -> None:
super(Rule, self).__init__(pattern, system=system)
self.replace = replace
self.delayed = delayed
# If delayed is True, and replace is a nested
# Condition expression, stores the conditions and the
# remaining stripped expression.
# This is going to be used to compare and sort rules,
# and also to decide if the rule matches an expression.
conds = []
if delayed:
while replace.has_form("System`Condition", 2):
replace, cond = replace.elements
conds.append(cond)

self.rhs_conditions = sorted(conds)
self.strip_replace = replace

def do_replace(self, expression, vars, options, evaluation):
new = self.replace.replace_vars(vars)
replace = self.replace if self.rhs_conditions == [] else self.strip_replace
for cond in self.rhs_conditions:
cond = cond.replace_vars(vars)
cond = cond.evaluate(evaluation)
if cond is not SymbolTrue:
raise StopMatchConditionFailed

new = replace.replace_vars(vars)
new.options = options

# if options is a non-empty dict, we need to ensure reevaluation of the whole expression, since 'new' will
Expand All @@ -159,6 +286,12 @@ def do_replace(self, expression, vars, options, evaluation):
def __repr__(self) -> str:
return "<Rule: %s -> %s>" % (self.pattern, self.replace)

def get_sort_key(self, pattern_sort=False) -> tuple:
# FIXME: check if this makes sense:
return tuple(
(self.system, self.pattern.get_sort_key(True), self.rhs_conditions)
)


class BuiltinRule(BaseRule):
"""
Expand Down
5 changes: 5 additions & 0 deletions mathics/core/systemsymbols.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@
SymbolSeries = Symbol("System`Series")
SymbolSeriesData = Symbol("System`SeriesData")
SymbolSet = Symbol("System`Set")
SymbolSetDelayed = Symbol("System`SetDelayed")
SymbolSign = Symbol("System`Sign")
SymbolSimplify = Symbol("System`Simplify")
SymbolSin = Symbol("System`Sin")
Expand All @@ -186,6 +187,8 @@
SymbolSubsuperscriptBox = Symbol("System`SubsuperscriptBox")
SymbolSuperscriptBox = Symbol("System`SuperscriptBox")
SymbolTable = Symbol("System`Table")
SymbolTagSet = Symbol("System`TagSet")
SymbolTagSetDelayed = Symbol("System`TagSetDelayed")
SymbolTeXForm = Symbol("System`TeXForm")
SymbolThrow = Symbol("System`Throw")
SymbolToString = Symbol("System`ToString")
Expand All @@ -194,5 +197,7 @@
SymbolUndefined = Symbol("System`Undefined")
SymbolUnequal = Symbol("System`Unequal")
SymbolUnevaluated = Symbol("System`Unevaluated")
SymbolUpSet = Symbol("System`UpSet")
SymbolUpSetDelayed = Symbol("System`UpSetDelayed")
SymbolUpValues = Symbol("System`UpValues")
SymbolXor = Symbol("System`Xor")
Loading

0 comments on commit 62cf433

Please sign in to comment.