Skip to content

Commit

Permalink
[verifier] Improve reasoning about constant square roots in Z3 (#199)
Browse files Browse the repository at this point in the history
* [verifier] Improve reasoning about constant square roots in Z3

* code format

* fix variable name
  • Loading branch information
xumingkuan authored Jan 24, 2025
1 parent 0cdc6f9 commit b7bfa78
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 37 deletions.
95 changes: 59 additions & 36 deletions src/python/verifier/gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,20 @@

import z3

# helper constants

sqrt2 = z3.Real("sqrt2")
sqrt3 = z3.Real("sqrt3")
sqrt5 = z3.Real("sqrt5")
sqrt_of_5_minus_sqrt5 = z3.Real("sqrt_of_5_minus_sqrt5")
kConstantEquations = [
sqrt2 * sqrt2 == 2,
sqrt3 * sqrt3 == 3,
sqrt5 * sqrt5 == 5,
sqrt_of_5_minus_sqrt5 * sqrt_of_5_minus_sqrt5 == 5 - sqrt5,
]


# helper methods


Expand Down Expand Up @@ -82,7 +96,7 @@ def mult(x, y):
return add(y, z)


def pi(n):
def pi(n, use_z3=True):
# This function handles fractions of pi with integer denominators.
assert isinstance(n, (int, float))
if isinstance(n, float):
Expand All @@ -98,9 +112,10 @@ def pi(n):
elif n == 2:
return 0, 1
elif n == 4:
cos_a = z3.Sqrt(2) / 2
sin_a = z3.Sqrt(2) / 2
return cos_a, sin_a
if use_z3:
return sqrt2 / 2, sqrt2 / 2
else:
return math.sqrt(2) / 2, math.sqrt(2) / 2
# Half-Angle Formula.
elif n % 2 == 0:
return half(pi(n // 2))
Expand All @@ -112,12 +127,20 @@ def pi(n):
# The equations are horrendous, and probably not useful in practice.
# If necessary, they could be implemented (at least for n = 17).
elif n == 3:
cos_a = 1 / 2
sin_a = z3.Sqrt(3) / 2
if use_z3:
cos_a = 1 / 2
sin_a = sqrt3 / 2
else:
cos_a = 1 / 2
sin_a = math.sqrt(3) / 2
return cos_a, sin_a
elif n == 5:
cos_a = (z3.Sqrt(5) + 1) / 4
sin_a = z3.Sqrt(2) * z3.Sqrt(5 - z3.Sqrt(5)) / 4
if use_z3:
cos_a = (sqrt5 + 1) / 4
sin_a = sqrt2 * sqrt_of_5_minus_sqrt5 / 4
else:
cos_a = (math.sqrt(5) + 1) / 4
sin_a = math.sqrt(2) * math.sqrt(5 - math.sqrt(5)) / 4
return cos_a, sin_a
elif n == 15:
return add(mult(2, pi(5)), neg(pi(3)))
Expand Down Expand Up @@ -169,12 +192,12 @@ def u2(phi, l, use_z3=True):
cos_l, sin_l = l
if use_z3:
return [
[(1 / z3.Sqrt(2), 0), (-1 / z3.Sqrt(2) * cos_l, -1 / z3.Sqrt(2) * sin_l)],
[(1 / sqrt2, 0), (-1 / sqrt2 * cos_l, -1 / sqrt2 * sin_l)],
[
(1 / z3.Sqrt(2) * cos_phi, 1 / z3.Sqrt(2) * sin_phi),
(1 / sqrt2 * cos_phi, 1 / sqrt2 * sin_phi),
(
1 / z3.Sqrt(2) * (cos_l * cos_phi - sin_l * sin_phi),
1 / z3.Sqrt(2) * (sin_phi * cos_l + sin_l * cos_phi),
1 / sqrt2 * (cos_l * cos_phi - sin_l * sin_phi),
1 / sqrt2 * (sin_phi * cos_l + sin_l * cos_phi),
),
],
]
Expand Down Expand Up @@ -236,8 +259,8 @@ def cp(phi, use_z3=True):
def h(use_z3=True):
if use_z3:
return [
[(1 / z3.Sqrt(2), 0), (1 / z3.Sqrt(2), 0)],
[(1 / z3.Sqrt(2), 0), (-1 / z3.Sqrt(2), 0)],
[(1 / sqrt2, 0), (1 / sqrt2, 0)],
[(1 / sqrt2, 0), (-1 / sqrt2, 0)],
]
else:
return [
Expand All @@ -256,14 +279,14 @@ def sdg(use_z3=True):

def t(use_z3=True):
if use_z3:
return [[(1, 0), (0, 0)], [(0, 0), (z3.Sqrt(2) / 2, z3.Sqrt(2) / 2)]]
return [[(1, 0), (0, 0)], [(0, 0), (sqrt2 / 2, sqrt2 / 2)]]
else:
return [[(1, 0), (0, 0)], [(0, 0), (math.sqrt(2) / 2, math.sqrt(2) / 2)]]


def tdg(use_z3=True):
if use_z3:
return [[(1, 0), (0, 0)], [(0, 0), (z3.Sqrt(2) / 2, -z3.Sqrt(2) / 2)]]
return [[(1, 0), (0, 0)], [(0, 0), (sqrt2 / 2, -sqrt2 / 2)]]
else:
return [[(1, 0), (0, 0)], [(0, 0), (math.sqrt(2) / 2, -math.sqrt(2) / 2)]]

Expand All @@ -287,26 +310,26 @@ def pdg(phi, use_z3=True):
def rx1(use_z3=True):
if use_z3:
return [
[(z3.Sqrt(2) / 2, 0), (0, -z3.Sqrt(2) / 2)],
[(0, -z3.Sqrt(2) / 2), (z3.Sqrt(2) / 2, 0)],
[(sqrt2 / 2, 0), (0, -sqrt2 / 2)],
[(0, -sqrt2 / 2), (sqrt2 / 2, 0)],
]
else:
return [
[(math.Sqrt(2) / 2, 0), (0, -math.Sqrt(2) / 2)],
[(0, -math.Sqrt(2) / 2), (math.Sqrt(2) / 2, 0)],
[(math.sqrt(2) / 2, 0), (0, -math.sqrt(2) / 2)],
[(0, -math.sqrt(2) / 2), (math.sqrt(2) / 2, 0)],
]


def rx3(use_z3=True):
if use_z3:
return [
[(z3.Sqrt(2) / 2, 0), (0, z3.Sqrt(2) / 2)],
[(0, z3.Sqrt(2) / 2), (z3.Sqrt(2) / 2, 0)],
[(sqrt2 / 2, 0), (0, sqrt2 / 2)],
[(0, sqrt2 / 2), (sqrt2 / 2, 0)],
]
else:
return [
[(math.Sqrt(2) / 2, 0), (0, math.Sqrt(2) / 2)],
[(0, math.Sqrt(2) / 2), (math.Sqrt(2) / 2, 0)],
[(math.sqrt(2) / 2, 0), (0, math.sqrt(2) / 2)],
[(0, math.sqrt(2) / 2), (math.sqrt(2) / 2, 0)],
]


Expand All @@ -322,8 +345,8 @@ def cz(use_z3=True):
def ry1(use_z3=True):
if use_z3:
return [
[(z3.Sqrt(2) / 2, 0), (-z3.Sqrt(2) / 2, 0)],
[(z3.Sqrt(2) / 2, 0), (z3.Sqrt(2) / 2, 0)],
[(sqrt2 / 2, 0), (-sqrt2 / 2, 0)],
[(sqrt2 / 2, 0), (sqrt2 / 2, 0)],
]
else:
return [
Expand All @@ -335,8 +358,8 @@ def ry1(use_z3=True):
def ry3(use_z3=True):
if use_z3:
return [
[(-z3.Sqrt(2) / 2, 0), (-z3.Sqrt(2) / 2, 0)],
[(z3.Sqrt(2) / 2, 0), (-z3.Sqrt(2) / 2, 0)],
[(-sqrt2 / 2, 0), (-sqrt2 / 2, 0)],
[(sqrt2 / 2, 0), (-sqrt2 / 2, 0)],
]
else:
return [
Expand All @@ -348,10 +371,10 @@ def ry3(use_z3=True):
def rxx1(use_z3=True):
if use_z3:
return [
[(z3.Sqrt(2) / 2, 0), (0, 0), (0, 0), (-z3.Sqrt(2) / 2, 0)],
[(0, 0), (z3.Sqrt(2) / 2, 0), (-z3.Sqrt(2) / 2, 0), (0, 0)],
[(0, 0), (-z3.Sqrt(2) / 2, 0), (z3.Sqrt(2) / 2, 0), (0, 0)],
[(-z3.Sqrt(2) / 2, 0), (0, 0), (0, 0), (z3.Sqrt(2) / 2, 0)],
[(sqrt2 / 2, 0), (0, 0), (0, 0), (-sqrt2 / 2, 0)],
[(0, 0), (sqrt2 / 2, 0), (-sqrt2 / 2, 0), (0, 0)],
[(0, 0), (-sqrt2 / 2, 0), (sqrt2 / 2, 0), (0, 0)],
[(-sqrt2 / 2, 0), (0, 0), (0, 0), (sqrt2 / 2, 0)],
]
else:
return [
Expand All @@ -365,10 +388,10 @@ def rxx1(use_z3=True):
def rxx3(use_z3=True):
if use_z3:
return [
[(-z3.Sqrt(2) / 2, 0), (0, 0), (0, 0), (-z3.Sqrt(2) / 2, 0)],
[(0, 0), (-z3.Sqrt(2) / 2, 0), (-z3.Sqrt(2) / 2, 0), (0, 0)],
[(0, 0), (-z3.Sqrt(2) / 2, 0), (-z3.Sqrt(2) / 2, 0), (0, 0)],
[(-z3.Sqrt(2) / 2, 0), (0, 0), (0, 0), (-z3.Sqrt(2) / 2, 0)],
[(-sqrt2 / 2, 0), (0, 0), (0, 0), (-sqrt2 / 2, 0)],
[(0, 0), (-sqrt2 / 2, 0), (-sqrt2 / 2, 0), (0, 0)],
[(0, 0), (-sqrt2 / 2, 0), (-sqrt2 / 2, 0), (0, 0)],
[(-sqrt2 / 2, 0), (0, 0), (0, 0), (-sqrt2 / 2, 0)],
]
else:
return [
Expand Down
8 changes: 7 additions & 1 deletion src/python/verifier/verifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import sys

import z3
from gates import kConstantEquations # for constant square roots
from gates import add, compute, get_matrix, neg # for searching phase factors

SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
Expand Down Expand Up @@ -317,9 +318,13 @@ def search_phase_factor_to_check_equivalence(
# Found a possible phase factor
# print(f'Checking phase factor {current_phase_factor_for_fingerprint}')
solver = z3.Solver()
# solver = z3.Tactic('qfnra-nlsat').solver()
solver.add(kConstantEquations)
solver.add(equations)
output_vec2_shifted = phase_shift(output_vec2, current_phase_factor_symbolic)
solver.add(z3.Not(z3.And(eq_vector(output_vec1, output_vec2_shifted))))
solver.add(
z3.simplify(z3.Not(z3.And(eq_vector(output_vec1, output_vec2_shifted))))
)
solver.set("timeout", timeout) # timeout in milliseconds
result = solver.check()
if result != z3.unsat:
Expand Down Expand Up @@ -411,6 +416,7 @@ def equivalent(
return False

solver = z3.Solver()
solver.add(kConstantEquations)
num_qubits = dag1_meta[meta_index_num_qubits]
equation_list = copy.deepcopy(
equation_list_for_params
Expand Down

0 comments on commit b7bfa78

Please sign in to comment.