Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix pat->Condition[...] apply rules #621

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ Bugs

# ``0`` with a given precision (like in ```0`3```) is now parsed as ``0``, an integer number.
#. ``RandomSample`` with one list argument now returns a random ordering of the list items. Previously it would return just one item.
#. Rules of the form ``pat->Condition[expr, cond]`` are handled as in WL. The same also works for nested `Condition` expressions. In particular, the comparison between two Rules with the same pattern but an iterated ``Condition`` expressionare considered equal if the conditions are the same.
#. Origin placement corrected on ``ListPlot`` and ``LinePlot``.
#. Fix long-standing bugs in Image handling

Expand Down
14 changes: 12 additions & 2 deletions mathics/builtin/assignments/assignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,11 +211,21 @@ class SetDelayed(Set):
'Condition' ('/;') can be used with 'SetDelayed' to make an
assignment that only holds if a condition is satisfied:
>> f[x_] := p[x] /; x>0
>> f[x_] := p[-x]/; x<-2
>> f[3]
= p[3]
>> f[-3]
= f[-3]
It also works if the condition is set in the LHS:
= p[3]
>> f[-1]
= f[-1]
Notice that the LHS is the same in both definitions, but the second
does not overwrite the first one.

To overwrite one of these definitions, we have to assign using the same condition:
>> f[x_] := Sin[x] /; x>0
>> f[3]
= Sin[3]
In a similar way, the condition can be set in the LHS:
>> F[x_, y_] /; x < y /; x>0 := x / y;
>> F[x_, y_] := y / x;
>> F[2, 3]
Expand Down
13 changes: 10 additions & 3 deletions mathics/builtin/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,8 @@ def create_rules(rules_expr, expr, name, evaluation, extra_args=[]):
else:
result = []
for rule in rules:
if rule.get_head_name() not in ("System`Rule", "System`RuleDelayed"):
head_name = rule.get_head_name()
if head_name not in ("System`Rule", "System`RuleDelayed"):
evaluation.message(name, "reps", rule)
return None, True
elif len(rule.elements) != 2:
Expand All @@ -177,7 +178,13 @@ def create_rules(rules_expr, expr, name, evaluation, extra_args=[]):
)
return None, True
else:
result.append(Rule(rule.elements[0], rule.elements[1]))
result.append(
Rule(
rule.elements[0],
rule.elements[1],
delayed=(head_name == "System`RuleDelayed"),
)
)
return result, False


Expand Down Expand Up @@ -1706,7 +1713,7 @@ def __init__(self, rulelist, evaluation):
self._elements = None
self._head = SymbolDispatch

def get_sort_key(self) -> tuple:
def get_sort_key(self, pattern_sort=False) -> tuple:
return self.src.get_sort_key()

def get_atom_name(self):
Expand Down
45 changes: 33 additions & 12 deletions mathics/core/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@

from mathics_scanner.tokeniser import full_names_pattern

from mathics.core.atoms import String
from mathics.core.atoms import Integer, String
from mathics.core.attributes import A_NO_ATTRIBUTES
from mathics.core.convert.expression import to_mathics_list
from mathics.core.element import fully_qualified_symbol_name
from mathics.core.expression import Expression
from mathics.core.pattern import Pattern
from mathics.core.rules import Rule
from mathics.core.symbols import Atom, Symbol, strip_context
from mathics.core.systemsymbols import SymbolGet

Expand Down Expand Up @@ -642,9 +644,6 @@ def get_ownvalue(self, name):
return None

def set_ownvalue(self, name, value) -> None:
from .expression import Symbol
from .rules import Rule

name = self.lookup_name(name)
self.add_rule(name, Rule(Symbol(name), value))
self.clear_cache(name)
Expand Down Expand Up @@ -680,8 +679,6 @@ def get_config_value(self, name, default=None):
return default

def set_config_value(self, name, new_value) -> None:
from mathics.core.expression import Integer

self.set_ownvalue(name, Integer(new_value))

def set_line_no(self, line_no) -> None:
Expand All @@ -701,6 +698,25 @@ def get_history_length(self):


def get_tag_position(pattern, name) -> Optional[str]:
# Strip first the pattern from HoldPattern, Pattern
# and Condition wrappings
while True:
# TODO: Not Atom/Expression,
# pattern -> pattern.to_expression()
if isinstance(pattern, Pattern):
pattern = pattern.expr
continue
if pattern.has_form("System`HoldPattern", 1):
pattern = pattern.elements[0]
continue
if pattern.has_form("System`Pattern", 2):
pattern = pattern.elements[1]
continue
if pattern.has_form("System`Condition", 2):
pattern = pattern.elements[0]
continue
break

if pattern.get_name() == name:
return "own"
elif isinstance(pattern, Atom):
Expand All @@ -709,10 +725,8 @@ def get_tag_position(pattern, name) -> Optional[str]:
head_name = pattern.get_head_name()
if head_name == name:
return "down"
elif head_name == "System`N" and len(pattern.elements) == 2:
elif pattern.has_form("System`N", 2):
return "n"
elif head_name == "System`Condition" and len(pattern.elements) > 0:
return get_tag_position(pattern.elements[0], name)
elif pattern.get_lookup_name() == name:
return "sub"
else:
Expand All @@ -722,11 +736,18 @@ def get_tag_position(pattern, name) -> Optional[str]:
return None


def insert_rule(values, rule) -> None:
def insert_rule(values: list, rule: Rule) -> None:
rhs_conds = getattr(rule, "rhs_conditions", [])
for index, existing in enumerate(values):
if existing.pattern.sameQ(rule.pattern):
del values[index]
break
# Check for coincidences in the replace conditions,
# it they are there.
# This ensures that the rules are equivalent even taking
# into accound the RHS conditions.
existing_rhs_conds = getattr(existing, "rhs_conditions", [])
if existing_rhs_conds == rhs_conds:
del values[index]
break
# use insort_left to guarantee that if equal rules exist, newer rules will
# get higher precedence by being inserted before them. see DownValues[].
bisect.insort_left(values, rule)
Expand Down
143 changes: 138 additions & 5 deletions mathics/core/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from mathics.core.element import KeyComparable
from mathics.core.expression import Expression
from mathics.core.pattern import Pattern, StopGenerator
from mathics.core.symbols import strip_context
from mathics.core.symbols import SymbolTrue, strip_context


def _python_function_arguments(f):
Expand All @@ -18,6 +18,10 @@ def function_arguments(f):
return _python_function_arguments(f)


class StopMatchConditionFailed(StopGenerator):
pass


class StopGenerator_BaseRule(StopGenerator):
pass

Expand Down Expand Up @@ -58,7 +62,11 @@ def yield_match(vars, rest):
if name.startswith("_option_"):
options[name[len("_option_") :]] = value
del vars[name]
new_expression = self.do_replace(expression, vars, options, evaluation)
try:
new_expression = self.do_replace(expression, vars, options, evaluation)
except StopMatchConditionFailed:
return

if new_expression is None:
new_expression = expression
if rest[0] or rest[1]:
Expand Down Expand Up @@ -106,7 +114,7 @@ def yield_match(vars, rest):
def do_replace(self):
raise NotImplementedError

def get_sort_key(self) -> tuple:
def get_sort_key(self, pattern_sort=False) -> tuple:
# FIXME: check if this makes sense:
return tuple((self.system, self.pattern.get_sort_key(True)))

Expand All @@ -130,12 +138,131 @@ class Rule(BaseRule):
``G[1.^2, a^2]``
"""

def __init__(self, pattern, replace, system=False) -> None:
def __ge__(self, other):
if isinstance(other, Rule):
sys, key, rhs_cond = self.get_sort_key()
sys_other, key_other, rhs_cond_other = other.get_sort_key()
if sys != sys_other:
return sys > sys_other
if key != key_other:
return key > key_other

# larger and more complex conditions come first
len_cond, len_cond_other = len(rhs_cond), len(rhs_cond_other)
if len_cond != len_cond_other:
return len_cond_other > len_cond
if len_cond == 0:
return False
for me_cond, other_cond in zip(rhs_cond, rhs_cond_other):
me_sk = me_cond.get_sort_key(True)
o_sk = other_cond.get_sort_key(True)
if me_sk > o_sk:
return False
return True
# Follow the usual rule
return self.get_sort_key(True) >= other.get_sort_key(True)

def __gt__(self, other):
if isinstance(other, Rule):
sys, key, rhs_cond = self.get_sort_key()
sys_other, key_other, rhs_cond_other = other.get_sort_key()
if sys != sys_other:
return sys > sys_other
if key != key_other:
return key > key_other

# larger and more complex conditions come first
len_cond, len_cond_other = len(rhs_cond), len(rhs_cond_other)
if len_cond != len_cond_other:
return len_cond_other > len_cond
if len_cond == 0:
return False

for me_cond, other_cond in zip(rhs_cond, rhs_cond_other):
me_sk = me_cond.get_sort_key(True)
o_sk = other_cond.get_sort_key(True)
if me_sk > o_sk:
return False
return me_sk > o_sk
# Follow the usual rule
return self.get_sort_key(True) > other.get_sort_key(True)

def __le__(self, other):
if isinstance(other, Rule):
sys, key, rhs_cond = self.get_sort_key()
sys_other, key_other, rhs_cond_other = other.get_sort_key()
if sys != sys_other:
return sys < sys_other
if key != key_other:
return key < key_other

# larger and more complex conditions come first
len_cond, len_cond_other = len(rhs_cond), len(rhs_cond_other)
if len_cond != len_cond_other:
return len_cond_other < len_cond
if len_cond == 0:
return False
for me_cond, other_cond in zip(rhs_cond, rhs_cond_other):
me_sk = me_cond.get_sort_key(True)
o_sk = other_cond.get_sort_key(True)
if me_sk < o_sk:
return False
return True
# Follow the usual rule
return self.get_sort_key(True) <= other.get_sort_key(True)

def __lt__(self, other):
if isinstance(other, Rule):
sys, key, rhs_cond = self.get_sort_key()
sys_other, key_other, rhs_cond_other = other.get_sort_key()
if sys != sys_other:
return sys < sys_other
if key != key_other:
return key < key_other

# larger and more complex conditions come first
len_cond, len_cond_other = len(rhs_cond), len(rhs_cond_other)
if len_cond != len_cond_other:
return len_cond_other < len_cond
if len_cond == 0:
return False

for me_cond, other_cond in zip(rhs_cond, rhs_cond_other):
me_sk = me_cond.get_sort_key(True)
o_sk = other_cond.get_sort_key(True)
if me_sk < o_sk:
return False
return me_sk > o_sk
# Follow the usual rule
return self.get_sort_key(True) < other.get_sort_key(True)

def __init__(self, pattern, replace, delayed=True, system=False) -> None:
super(Rule, self).__init__(pattern, system=system)
self.replace = replace
self.delayed = delayed
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This property helps to distinguish if the Rule comes from a Rule expression or a RuleDelayed expression, and it is needed to handle the different behavior observed in WMA. I do not see a reason for this different behavior, but maybe it is, and maybe there is a different behavior in other contexts.

# If delayed is True, and replace is a nested
# Condition expression, stores the conditions and the
# remaining stripped expression.
# This is going to be used to compare and sort rules,
# and also to decide if the rule matches an expression.
conds = []
if delayed:
while replace.has_form("System`Condition", 2):
replace, cond = replace.elements
conds.append(cond)

self.rhs_conditions = sorted(conds)
self.strip_replace = replace

def do_replace(self, expression, vars, options, evaluation):
new = self.replace.replace_vars(vars)
replace = self.replace if self.rhs_conditions == [] else self.strip_replace
for cond in self.rhs_conditions:
cond = cond.replace_vars(vars)
cond = cond.evaluate(evaluation)
if cond is not SymbolTrue:
raise StopMatchConditionFailed

new = replace.replace_vars(vars)
new.options = options

# if options is a non-empty dict, we need to ensure reevaluation of the whole expression, since 'new' will
Expand All @@ -158,6 +285,12 @@ def do_replace(self, expression, vars, options, evaluation):
def __repr__(self) -> str:
return "<Rule: %s -> %s>" % (self.pattern, self.replace)

def get_sort_key(self, pattern_sort=False) -> tuple:
# FIXME: check if this makes sense:
return tuple(
(self.system, self.pattern.get_sort_key(True), self.rhs_conditions)
)


class BuiltinRule(BaseRule):
"""
Expand Down
5 changes: 5 additions & 0 deletions mathics/core/systemsymbols.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@
SymbolSeries = Symbol("System`Series")
SymbolSeriesData = Symbol("System`SeriesData")
SymbolSet = Symbol("System`Set")
SymbolSetDelayed = Symbol("System`SetDelayed")
SymbolSign = Symbol("System`Sign")
SymbolSimplify = Symbol("System`Simplify")
SymbolSin = Symbol("System`Sin")
Expand All @@ -201,6 +202,8 @@
SymbolSubsuperscriptBox = Symbol("System`SubsuperscriptBox")
SymbolSuperscriptBox = Symbol("System`SuperscriptBox")
SymbolTable = Symbol("System`Table")
SymbolTagSet = Symbol("System`TagSet")
SymbolTagSetDelayed = Symbol("System`TagSetDelayed")
SymbolTan = Symbol("System`Tan")
SymbolTanh = Symbol("System`Tanh")
SymbolTeXForm = Symbol("System`TeXForm")
Expand All @@ -212,5 +215,7 @@
SymbolUndefined = Symbol("System`Undefined")
SymbolUnequal = Symbol("System`Unequal")
SymbolUnevaluated = Symbol("System`Unevaluated")
SymbolUpSet = Symbol("System`UpSet")
SymbolUpSetDelayed = Symbol("System`UpSetDelayed")
SymbolUpValues = Symbol("System`UpValues")
SymbolXor = Symbol("System`Xor")
Loading