Skip to content

Commit

Permalink
Remove _value_equality_values_cls_ and update tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
codrut3 committed Mar 2, 2025
1 parent 73914bb commit 2e2df0a
Show file tree
Hide file tree
Showing 7 changed files with 53 additions and 25 deletions.
10 changes: 7 additions & 3 deletions cirq-core/cirq/ops/common_gates_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -941,9 +941,13 @@ def test_cx_cz_stabilizer(gate):


def test_phase_by_xy():
assert cirq.phase_by(cirq.X, 0.25, 0) == cirq.Y
assert cirq.phase_by(cirq.X**0.5, 0.25, 0) == cirq.Y**0.5
assert cirq.phase_by(cirq.X**-0.5, 0.25, 0) == cirq.Y**-0.5
assert cirq.phase_by(cirq.X, 0.25, 0) == cirq.PhasedXPowGate(phase_exponent=0.5)
assert cirq.phase_by(cirq.X**0.5, 0.25, 0) == cirq.PhasedXPowGate(
exponent=0.5, phase_exponent=0.5
)
assert cirq.phase_by(cirq.X**-0.5, 0.25, 0) == cirq.PhasedXPowGate(
exponent=-0.5, phase_exponent=0.5
)


def test_ixyz_circuit_diagram():
Expand Down
8 changes: 1 addition & 7 deletions cirq-core/cirq/ops/phased_x_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from cirq.ops import raw_types


@value.value_equality(manual_cls=True, approximate=True)
@value.value_equality(approximate=True)
class PhasedXPowGate(raw_types.Gate):
r"""A gate equivalent to $Z^{-p} X^t Z^{p}$ (in time order).
Expand Down Expand Up @@ -242,14 +242,8 @@ def _canonical_exponent(self):

return self._exponent % period

def _value_equality_values_cls_(self):
return PhasedXPowGate

def _value_equality_values_(self):
return self.phase_exponent, self._canonical_exponent, self._global_shift

def _value_equality_approximate_values_(self):
return self.phase_exponent, self._canonical_exponent, self._global_shift

def _json_dict_(self) -> Dict[str, Any]:
return protocols.obj_to_dict_helper(self, ['phase_exponent', 'exponent', 'global_shift'])
43 changes: 34 additions & 9 deletions cirq-core/cirq/transformers/eject_phased_paulis_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,11 +196,17 @@ def test_crosses_czs():
# Partial CZ.
assert_optimizes(
before=quick_circuit([cirq.X(a)], [cirq.CZ(a, b) ** 0.25]),
expected=quick_circuit([cirq.Z(b) ** 0.25], [cirq.CZ(a, b) ** -0.25], [cirq.X(a)]),
expected=quick_circuit(
[cirq.Z(b) ** 0.25],
[cirq.CZ(a, b) ** -0.25],
[cirq.PhasedXPowGate(phase_exponent=0)(a)],
),
)
assert_optimizes(
before=quick_circuit([cirq.X(a)], [cirq.CZ(a, b) ** x]),
expected=quick_circuit([cirq.Z(b) ** x], [cirq.CZ(a, b) ** -x], [cirq.X(a)]),
expected=quick_circuit(
[cirq.Z(b) ** x], [cirq.CZ(a, b) ** -x], [cirq.PhasedXPowGate(phase_exponent=0)(a)]
),
eject_parameterized=True,
)

Expand Down Expand Up @@ -380,7 +386,8 @@ def test_phases_partial_ws():
[cirq.X(q)], [cirq.PhasedXPowGate(phase_exponent=0.25, exponent=0.5).on(q)]
),
expected=quick_circuit(
[cirq.PhasedXPowGate(phase_exponent=-0.25, exponent=0.5).on(q)], [cirq.X(q)]
[cirq.PhasedXPowGate(phase_exponent=-0.25, exponent=0.5).on(q)],
[cirq.PhasedXPowGate(phase_exponent=0)(q)],
),
)

Expand All @@ -398,7 +405,8 @@ def test_phases_partial_ws():
[cirq.PhasedXPowGate(phase_exponent=0.5, exponent=0.75).on(q)],
),
expected=quick_circuit(
[cirq.X(q) ** 0.75], [cirq.PhasedXPowGate(phase_exponent=0.25).on(q)]
[cirq.PhasedXPowGate(phase_exponent=0)(q) ** 0.75],
[cirq.PhasedXPowGate(phase_exponent=0.25).on(q)],
),
)

Expand All @@ -407,7 +415,8 @@ def test_phases_partial_ws():
[cirq.X(q)], [cirq.PhasedXPowGate(exponent=-0.25, phase_exponent=0.5).on(q)]
),
expected=quick_circuit(
[cirq.PhasedXPowGate(exponent=-0.25, phase_exponent=-0.5).on(q)], [cirq.X(q)]
[cirq.PhasedXPowGate(exponent=-0.25, phase_exponent=-0.5).on(q)],
[cirq.PhasedXPowGate(phase_exponent=0)(q)],
),
)

Expand All @@ -431,18 +440,30 @@ def test_blocked_by_unknown_and_symbols(sym):

assert_optimizes(
before=quick_circuit([cirq.X(a)], [cirq.SWAP(a, b)], [cirq.X(a)]),
expected=quick_circuit([cirq.X(a)], [cirq.SWAP(a, b)], [cirq.X(a)]),
expected=quick_circuit(
[cirq.PhasedXPowGate(phase_exponent=0)(a)],
[cirq.SWAP(a, b)],
[cirq.PhasedXPowGate(phase_exponent=0)(a)],
),
)

assert_optimizes(
before=quick_circuit([cirq.X(a)], [cirq.Z(a) ** sym], [cirq.X(a)]),
expected=quick_circuit([cirq.X(a)], [cirq.Z(a) ** sym], [cirq.X(a)]),
expected=quick_circuit(
[cirq.PhasedXPowGate(phase_exponent=0)(a)],
[cirq.Z(a) ** sym],
[cirq.PhasedXPowGate(phase_exponent=0)(a)],
),
compare_unitaries=False,
)

assert_optimizes(
before=quick_circuit([cirq.X(a)], [cirq.CZ(a, b) ** sym], [cirq.X(a)]),
expected=quick_circuit([cirq.X(a)], [cirq.CZ(a, b) ** sym], [cirq.X(a)]),
expected=quick_circuit(
[cirq.PhasedXPowGate(phase_exponent=0)(a)],
[cirq.CZ(a, b) ** sym],
[cirq.PhasedXPowGate(phase_exponent=0)(a)],
),
compare_unitaries=False,
)

Expand All @@ -453,7 +474,11 @@ def test_blocked_by_nocompile_tag():

assert_optimizes(
before=quick_circuit([cirq.X(a)], [cirq.CZ(a, b).with_tags("nocompile")], [cirq.X(a)]),
expected=quick_circuit([cirq.X(a)], [cirq.CZ(a, b).with_tags("nocompile")], [cirq.X(a)]),
expected=quick_circuit(
[cirq.PhasedXPowGate(phase_exponent=0)(a)],
[cirq.CZ(a, b).with_tags("nocompile")],
[cirq.PhasedXPowGate(phase_exponent=0)(a)],
),
with_context=True,
)

Expand Down
6 changes: 5 additions & 1 deletion cirq-core/cirq/transformers/eject_z_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,11 @@ def test_z_pushes_past_xy_and_phases_it():
assert_optimizes(
before=cirq.Circuit([cirq.Moment([cirq.Z(q) ** 0.5]), cirq.Moment([cirq.Y(q) ** 0.25])]),
expected=cirq.Circuit(
[cirq.Moment(), cirq.Moment([cirq.X(q) ** 0.25]), cirq.Moment([cirq.Z(q) ** 0.5])]
[
cirq.Moment(),
cirq.Moment([cirq.PhasedXPowGate(phase_exponent=0)(q) ** 0.25]),
cirq.Moment([cirq.Z(q) ** 0.5]),
]
),
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_merge_single_qubit_gates_to_phased_x_and_z():
optimized=cirq.merge_single_qubit_gates_to_phased_x_and_z(c),
expected=cirq.Circuit(
cirq.PhasedXPowGate(phase_exponent=1)(a),
cirq.Y(b) ** 0.5,
cirq.PhasedXPowGate(phase_exponent=0.5)(b) ** 0.5,
cirq.CZ(a, b),
(cirq.PhasedXPowGate(phase_exponent=-0.5)(a)) ** 0.5,
cirq.measure(b, key="m"),
Expand Down
6 changes: 3 additions & 3 deletions cirq-google/cirq_google/api/v1/programs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def assert_proto_dict_convert(gate: cirq.Gate, proto: operations_pb2.Operation,
def test_protobuf_round_trip():
qubits = cirq.GridQubit.rect(1, 5)
circuit = cirq.Circuit(
[cirq.X(q) ** 0.5 for q in qubits],
[cirq.PhasedXPowGate(phase_exponent=0)(q) ** 0.5 for q in qubits],
[
cirq.CZ(q, q2)
for q in [cirq.GridQubit(0, 0)]
Expand Down Expand Up @@ -245,7 +245,7 @@ def test_w_to_proto():
)
assert_proto_dict_convert(gate, proto, cirq.GridQubit(2, 3))

gate = cirq.X**0.25
gate = cirq.PhasedXPowGate(exponent=0.25, phase_exponent=0)
proto = operations_pb2.Operation(
exp_w=operations_pb2.ExpW(
target=operations_pb2.Qubit(row=2, col=3),
Expand All @@ -255,7 +255,7 @@ def test_w_to_proto():
)
assert_proto_dict_convert(gate, proto, cirq.GridQubit(2, 3))

gate = cirq.Y**0.25
gate = cirq.PhasedXPowGate(exponent=0.25, phase_exponent=0.5)
proto = operations_pb2.Operation(
exp_w=operations_pb2.ExpW(
target=operations_pb2.Qubit(row=2, col=3),
Expand Down
3 changes: 2 additions & 1 deletion cirq-google/cirq_google/engine/engine_program_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,8 @@ def test_get_circuit_v1(get_program_async):
@mock.patch('cirq_google.engine.engine_client.EngineClient.get_program_async')
def test_get_circuit_v2(get_program_async):
circuit = cirq.Circuit(
cirq.X(cirq.GridQubit(5, 2)) ** 0.5, cirq.measure(cirq.GridQubit(5, 2), key='result')
cirq.PhasedXPowGate(phase_exponent=0)(cirq.GridQubit(5, 2)) ** 0.5,
cirq.measure(cirq.GridQubit(5, 2), key='result'),
)

program = cg.EngineProgram('a', 'b', EngineContext())
Expand Down

0 comments on commit 2e2df0a

Please sign in to comment.