Skip to content

Commit

Permalink
Remove direct statespace references from symbolics
Browse files Browse the repository at this point in the history
  • Loading branch information
pschanely committed Mar 24, 2024
1 parent 751b019 commit e698c34
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 49 deletions.
22 changes: 10 additions & 12 deletions crosshair/copyext_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,17 @@ def test_deepcopyext_symbolic_set():
deepcopyext(s, CopyMode.REALIZE, {})


def test_deepcopyext_realize():
with standalone_statespace, NoTracing():
x = SymbolicInt("x")
lock = RLock()
lockarray = [lock]
input = {"a": ([x], lockarray, lockarray)}
def test_deepcopyext_realize(space):
x = SymbolicInt("x")
lock = RLock()
lockarray = [lock]
input = {"a": ([x], lockarray, lockarray)}
output = deepcopyext(input, CopyMode.REALIZE, {})
with NoTracing():
assert input is not output
assert input["a"] is not output["a"]
assert output["a"][1] is output["a"][2] # memo preserves identity
assert type(input["a"][0][0]) is SymbolicInt
assert type(output["a"][0][0]) is int
assert input is not output
assert input["a"] is not output["a"]
assert output["a"][1] is output["a"][2] # memo preserves identity
assert type(input["a"][0][0]) is SymbolicInt
assert type(output["a"][0][0]) is int


def test_deepcopyext_tuple_type():
Expand Down
7 changes: 3 additions & 4 deletions crosshair/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1220,10 +1220,9 @@ def dothing(foo: WithUnpickleableArg) -> int:


@pytest.mark.smoke
def test_deep_realize():
with standalone_statespace as space:
x = proxy_for_type(int, "x")
space.add(x.var == 4)
def test_deep_realize(space):
x = proxy_for_type(int, "x")
space.add(x.var == 4)

@dataclasses.dataclass
class Woo:
Expand Down
65 changes: 32 additions & 33 deletions crosshair/libimpl/builtinslib.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,6 @@ class SymbolicValue(CrossHairValue):
def __init__(self, smtvar: Union[str, z3.ExprRef], typ: Type):
if is_tracing():
raise CrosshairInternal
self.statespace = context_statespace()
self.snapshot = SnapshotRef(-1)
self.python_type = typ
if type(smtvar) is str:
Expand All @@ -264,7 +263,7 @@ def __init_var__(self, typ, varname):

def __deepcopy__(self, memo):
result = copy.copy(self)
result.snapshot = self.statespace.current_snapshot()
result.snapshot = context_statespace().current_snapshot()
memo[id(self)] = result
return result

Expand Down Expand Up @@ -1007,7 +1006,7 @@ def _smt_promote_literal(cls, literal) -> Optional[z3.SortRef]:

def __ch_realize__(self) -> object:
with NoTracing():
return self.statespace.choose_possible(self.var)
return context_statespace().choose_possible(self.var)

def __neg__(self):
with NoTracing():
Expand All @@ -1025,7 +1024,7 @@ def __index__(self):

def __bool__(self):
with NoTracing():
return self.statespace.choose_possible(self.var)
return context_statespace().choose_possible(self.var)

def __int__(self):
with NoTracing():
Expand Down Expand Up @@ -1068,7 +1067,7 @@ def from_bytes(cls, b: bytes, byteorder: str, signed=False) -> int:
return int.from_bytes(b, byteorder, signed=signed) # type: ignore

def __ch_realize__(self) -> object:
return self.statespace.find_model_value(self.var)
return context_statespace().find_model_value(self.var)

def __repr__(self):
if self < 0:
Expand Down Expand Up @@ -1236,13 +1235,13 @@ def _smt_promote_literal(cls, literal) -> Optional[z3.SortRef]:
return None

def __ch_realize__(self) -> object:
return self.statespace.find_model_value(self.var).__float__() # type: ignore
return context_statespace().find_model_value(self.var).__float__() # type: ignore

def __repr__(self):
return self.statespace.find_model_value(self.var).__repr__()
return context_statespace().find_model_value(self.var).__repr__()

def __hash__(self):
return self.statespace.find_model_value(self.var).__hash__()
return context_statespace().find_model_value(self.var).__hash__()

def __bool__(self):
with NoTracing():
Expand Down Expand Up @@ -1344,7 +1343,7 @@ def __init__(self, smtvar: Union[str, z3.ExprRef], typ: Type):
self.ch_key_type = None
self.smt_key_sort = HeapRef
SymbolicValue.__init__(self, smtvar, typ)
self.statespace.add(self._len() >= 0)
context_statespace().add(self._len() >= 0)

def __ch_realize__(self):
return origin_of(self.python_type)(self)
Expand Down Expand Up @@ -1402,9 +1401,10 @@ def __init_var__(self, typ, varname):
arr_smt_sort = z3.ArraySort(
self.smt_key_sort, possibly_missing_sort(self.smt_val_sort)
)
space = context_statespace()
return (
z3.Const(varname + "_map" + self.statespace.uniq(), arr_smt_sort),
z3.Const(varname + "_len" + self.statespace.uniq(), _SMT_INT_SORT),
z3.Const(varname + "_map" + space.uniq(), arr_smt_sort),
z3.Const(varname + "_len" + space.uniq(), _SMT_INT_SORT),
)

def __eq__(self, other):
Expand Down Expand Up @@ -1455,7 +1455,7 @@ def __getitem__(self, k):
if SymbolicBool(self._len() == 0).__bool__():
raise IgnoreAttempt("SymbolicDict in inconsistent state")
return smt_to_ch_value(
self.statespace,
context_statespace(),
self.snapshot,
self.val_accessor(possibly_missing),
self.val_pytype,
Expand All @@ -1468,7 +1468,7 @@ def __iter__(self):
with NoTracing():
arr_var, len_var = self.var
iter_cache = self._iter_cache
space = self.statespace
space = context_statespace()
idx = 0
arr_sort = self._arr().sort()
is_missing = self.val_missing_checker
Expand Down Expand Up @@ -1509,7 +1509,7 @@ def __init__(self, smtvar: Union[str, z3.ExprRef], typ: Type):
SymbolicDictOrSet.__init__(self, smtvar, typ)
self._iter_cache: List[z3.Const] = []
self.empty = z3.K(self._arr().sort().domain(), False)
self.statespace.add((self._arr() == self.empty) == (self._len() == 0))
context_statespace().add((self._arr() == self.empty) == (self._len() == 0))

def __ch_realize__(self):
return python_type(self)(map(realize, self))
Expand Down Expand Up @@ -1543,12 +1543,13 @@ def __eq__(self, other):

def __init_var__(self, typ, varname):
assert typ == self.python_type
space = context_statespace()
return (
z3.Const(
varname + "_map" + self.statespace.uniq(),
varname + "_map" + space.uniq(),
z3.ArraySort(self.smt_key_sort, _SMT_BOOL_SORT),
),
z3.Const(varname + "_len" + self.statespace.uniq(), _SMT_INT_SORT),
z3.Const(varname + "_len" + space.uniq(), _SMT_INT_SORT),
)

def __contains__(self, key):
Expand All @@ -1572,7 +1573,7 @@ def __iter__(self):
with NoTracing():
arr_var, len_var = self.var
iter_cache = self._iter_cache
space = self.statespace
space = context_statespace()
idx = 0
arr_sort = self._arr().sort()
keys_on_heap = is_heapref_sort(arr_sort.domain())
Expand Down Expand Up @@ -1611,9 +1612,7 @@ def __iter__(self):
arr_var = remaining
# In this conditional, we reconcile the parallel symbolic variables for length
# and contents:
if self.statespace.choose_possible(
arr_var != self.empty, probability_true=0.0
):
if space.choose_possible(arr_var != self.empty, probability_true=0.0):
raise IgnoreAttempt("SymbolicSet in inconsistent state")

def _set_op(self, attr, other):
Expand Down Expand Up @@ -1799,14 +1798,15 @@ def __init__(self, smtvar: Union[str, z3.ExprRef], typ: Any):

SymbolicValue.__init__(self, smtvar, typ)
len_var = self._len()
self.statespace.add(len_var >= 0)
context_statespace().add(len_var >= 0)

def __init_var__(self, typ, varname):
assert typ == self.python_type
arr_smt_type = z3.ArraySort(_SMT_INT_SORT, self.item_smt_sort)
space = context_statespace()
return (
z3.Const(varname + "_map" + self.statespace.uniq(), arr_smt_type),
z3.Const(varname + "_len" + self.statespace.uniq(), _SMT_INT_SORT),
z3.Const(varname + "_map" + space.uniq(), arr_smt_type),
z3.Const(varname + "_len" + space.uniq(), _SMT_INT_SORT),
)

def _arr(self):
Expand Down Expand Up @@ -1872,7 +1872,7 @@ def __radd__(self, other: object):
return NotImplemented

def __contains__(self, other):
space = self.statespace
space = context_statespace()
with NoTracing():
if not is_heapref_sort(self.item_smt_sort):
smt_other = self.ch_item_type._coerce_to_smt_sort(other)
Expand All @@ -1895,7 +1895,7 @@ def __contains__(self, other):
return False

def __getitem__(self, i):
space = self.statespace
space = context_statespace()
with NoTracing():
if (
isinstance(i, slice)
Expand Down Expand Up @@ -2127,7 +2127,7 @@ def _is_superclass_of_(self, other):
if type(other) is SymbolicType:
# Prefer it this way because only _is_subcless_of_ does the type cap lowering.
return other._is_subclass_of_(self)
space = self.statespace
space = context_statespace()
coerced = SymbolicType._coerce_to_smt_sort(other)
if coerced is None:
return False
Expand All @@ -2143,7 +2143,7 @@ def _is_subclass_of_(self, other):
assert not is_tracing()
if self is SymbolicType:
return False
space = self.statespace
space = context_statespace()
coerced = SymbolicType._coerce_to_smt_sort(other)
if coerced is None:
return False
Expand Down Expand Up @@ -2185,7 +2185,7 @@ def _realized(self):
def _realize(self) -> Type:
with NoTracing():
cap = self.pytype_cap
space = self.statespace
space = context_statespace()
type_repo = space.extra(SymbolicTypeRepository)
if cap is object:
# We don't attempt every possible Python type! Just some basic ones.
Expand Down Expand Up @@ -3341,7 +3341,7 @@ def _smt_promote_literal(cls, literal) -> Optional[z3.SortRef]:
return None

def __ch_realize__(self) -> object:
codepoints = self.statespace.find_model_value(self.var)
codepoints = context_statespace().find_model_value(self.var)
return "".join(chr(x) for x in codepoints)

def __copy__(self):
Expand Down Expand Up @@ -3386,7 +3386,6 @@ def __radd__(self, other):
raise TypeError

def __mul__(self, other):
self.statespace
if isinstance(other, Integral):
if other <= 1:
return self if other == 1 else ""
Expand All @@ -3409,7 +3408,7 @@ def __contains__(self, other):
def __getitem__(self, i: Union[int, slice]):
with NoTracing():
idx_or_pair = process_slice_vs_symbolic_len(
self.statespace, i, z3.Length(self.var)
context_statespace(), i, z3.Length(self.var)
)
if isinstance(idx_or_pair, tuple):
(start, stop) = idx_or_pair
Expand All @@ -3427,7 +3426,7 @@ def find(self, substr, start=None, end=None):
if not isinstance(substr, str):
raise TypeError
with NoTracing():
space = self.statespace
space = context_statespace()
smt_my_len = z3.Length(self.var)
if start is None and end is None:
smt_start = z3IntVal(0)
Expand Down Expand Up @@ -3493,7 +3492,7 @@ def rfind(self, substr, start=None, end=None) -> Union[int, SymbolicInt]:
if not isinstance(substr, str):
raise TypeError
with NoTracing():
space = self.statespace
space = context_statespace()
smt_my_len = z3.Length(self.var)
if start is None and end is None:
smt_start = z3IntVal(0)
Expand Down

0 comments on commit e698c34

Please sign in to comment.