Skip to content

Commit

Permalink
Remove stale variable_class_operations set from Target (Qiskit#12957
Browse files Browse the repository at this point in the history
)

* Fix: Return None for variadic properties.

* Fix: Remove `variable_class_operations` from `Target`.
- When performing serialization, we were forgetting to include `variable_class_operations` set of names in the state mapping. Since the nature of `TargetOperation` is to work as an enum of either `Instruction` instances or class aliases that would represent `Variadic` instructiions. The usage of that structure was redundand, so it was removed.
- `num_qubits` returns an instance of `u32`, callers will need to make sure they're dealing with a `NormalOperation`.
- `params` behaves more similarly, returning a slice of `Param` instances. Will panic if called on a `Variadic` operation.
- Re-adapt the code to work without `variable_class_operations`.
- Add test case to check for something similar to what was mentioned by @doichanj in Qiskit#12953.

* Fix: Use `UnitaryGate` as the example in test-case.
- Move import of `pickle` to top of the file.
  • Loading branch information
raynelfss authored Oct 24, 2024
1 parent 214e0a4 commit f2e07bc
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 34 deletions.
75 changes: 41 additions & 34 deletions crates/accelerate/src/target_transpiler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use std::ops::Index;
use ahash::RandomState;

use hashbrown::HashSet;
use indexmap::{IndexMap, IndexSet};
use indexmap::IndexMap;
use itertools::Itertools;
use nullable_index_map::NullableIndexMap;
use pyo3::{
Expand Down Expand Up @@ -57,7 +57,7 @@ type GateMapState = Vec<(String, Vec<(Option<Qargs>, Option<InstructionPropertie

/// Represents a Qiskit `Gate` object or a Variadic instruction.
/// Keeps a reference to its Python instance for caching purposes.
#[derive(Debug, Clone, FromPyObject)]
#[derive(FromPyObject, Debug, Clone)]
pub(crate) enum TargetOperation {
Normal(NormalOperation),
Variadic(PyObject),
Expand All @@ -82,19 +82,23 @@ impl ToPyObject for TargetOperation {
}

impl TargetOperation {
fn num_qubits(&self) -> u32 {
/// Gets the number of qubits of a [TargetOperation], will panic if the operation is [TargetOperation::Variadic].
pub fn num_qubits(&self) -> u32 {
match &self {
Self::Normal(normal) => normal.operation.view().num_qubits(),
Self::Normal(normal) => normal.operation.num_qubits(),
Self::Variadic(_) => {
unreachable!("'num_qubits' property is reserved for normal operations only.")
panic!("'num_qubits' property doesn't exist for Variadic operations")
}
}
}

fn params(&self) -> &[Param] {
/// Gets the parameters of a [TargetOperation], will panic if the operation is [TargetOperation::Variadic].
pub fn params(&self) -> &[Param] {
match &self {
TargetOperation::Normal(normal) => normal.params.as_slice(),
TargetOperation::Variadic(_) => &[],
TargetOperation::Variadic(_) => {
panic!("'parameters' property doesn't exist for Variadic operations")
}
}
}
}
Expand Down Expand Up @@ -173,7 +177,6 @@ pub(crate) struct Target {
#[pyo3(get)]
_gate_name_map: IndexMap<String, TargetOperation, RandomState>,
global_operations: IndexMap<u32, HashSet<String>, RandomState>,
variable_class_operations: IndexSet<String, RandomState>,
qarg_gate_map: NullableIndexMap<Qargs, Option<HashSet<String>>>,
non_global_strict_basis: Option<Vec<String>>,
non_global_basis: Option<Vec<String>>,
Expand Down Expand Up @@ -269,7 +272,6 @@ impl Target {
concurrent_measurements,
gate_map: GateMap::default(),
_gate_name_map: IndexMap::default(),
variable_class_operations: IndexSet::default(),
global_operations: IndexMap::default(),
qarg_gate_map: NullableIndexMap::default(),
non_global_basis: None,
Expand Down Expand Up @@ -302,16 +304,15 @@ impl Target {
)));
}
let mut qargs_val: PropsMap;
match instruction {
match &instruction {
TargetOperation::Variadic(_) => {
qargs_val = PropsMap::with_capacity(1);
qargs_val.extend([(None, None)]);
self.variable_class_operations.insert(name.to_string());
}
TargetOperation::Normal(_) => {
TargetOperation::Normal(normal) => {
if let Some(mut properties) = properties {
qargs_val = PropsMap::with_capacity(properties.len());
let inst_num_qubits = instruction.num_qubits();
let inst_num_qubits = normal.operation.view().num_qubits();
if properties.contains_key(None) {
self.global_operations
.entry(inst_num_qubits)
Expand Down Expand Up @@ -619,7 +620,7 @@ impl Target {
} else if let Some(operation_name) = operation_name {
if let Some(parameters) = parameters {
if let Some(obj) = self._gate_name_map.get(&operation_name) {
if self.variable_class_operations.contains(&operation_name) {
if matches!(obj, TargetOperation::Variadic(_)) {
if let Some(_qargs) = qargs {
let qarg_set: HashSet<PhysicalQubit> = _qargs.iter().cloned().collect();
return Ok(_qargs
Expand Down Expand Up @@ -1053,8 +1054,8 @@ impl Target {
if let Some(Some(qarg_gate_map_arg)) = self.qarg_gate_map.get(qargs).as_ref() {
res.extend(qarg_gate_map_arg.iter().map(|key| key.as_str()));
}
for name in self._gate_name_map.keys() {
if self.variable_class_operations.contains(name) {
for (name, obj) in self._gate_name_map.iter() {
if matches!(obj, TargetOperation::Variadic(_)) {
res.insert(name);
}
}
Expand Down Expand Up @@ -1160,34 +1161,40 @@ impl Target {
}
if gate_prop_name.contains_key(None) {
let obj = &self._gate_name_map[operation_name];
if self.variable_class_operations.contains(operation_name) {
match obj {
TargetOperation::Variadic(_) => {
return qargs.is_none()
|| _qargs.iter().all(|qarg| {
qarg.index() <= self.num_qubits.unwrap_or_default()
}) && qarg_set.len() == _qargs.len();
}
TargetOperation::Normal(obj) => {
let qubit_comparison = obj.operation.num_qubits();
return qubit_comparison == _qargs.len() as u32
&& _qargs.iter().all(|qarg| {
qarg.index() < self.num_qubits.unwrap_or_default()
});
}
}
}
} else {
// Duplicate case is if it contains none
let obj = &self._gate_name_map[operation_name];
match obj {
TargetOperation::Variadic(_) => {
return qargs.is_none()
|| _qargs.iter().all(|qarg| {
qarg.index() <= self.num_qubits.unwrap_or_default()
}) && qarg_set.len() == _qargs.len();
} else {
let qubit_comparison = obj.num_qubits();
}
TargetOperation::Normal(obj) => {
let qubit_comparison = obj.operation.num_qubits();
return qubit_comparison == _qargs.len() as u32
&& _qargs.iter().all(|qarg| {
qarg.index() < self.num_qubits.unwrap_or_default()
});
}
}
} else {
// Duplicate case is if it contains none
if self.variable_class_operations.contains(operation_name) {
return qargs.is_none()
|| _qargs
.iter()
.all(|qarg| qarg.index() <= self.num_qubits.unwrap_or_default())
&& qarg_set.len() == _qargs.len();
} else {
let qubit_comparison = self._gate_name_map[operation_name].num_qubits();
return qubit_comparison == _qargs.len() as u32
&& _qargs
.iter()
.all(|qarg| qarg.index() < self.num_qubits.unwrap_or_default());
}
}
} else {
return true;
Expand Down
19 changes: 19 additions & 0 deletions test/python/transpiler/test_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# that they have been altered from the originals.

# pylint: disable=missing-docstring
from pickle import loads, dumps

import math
import numpy as np
Expand All @@ -30,6 +31,7 @@
CCXGate,
RZXGate,
CZGate,
UnitaryGate,
)
from qiskit.circuit import IfElseOp, ForLoopOp, WhileLoopOp, SwitchCaseOp
from qiskit.circuit.measure import Measure
Expand Down Expand Up @@ -1166,6 +1168,23 @@ def test_instruction_supported_no_args(self):
def test_instruction_supported_no_operation(self):
self.assertFalse(self.ibm_target.instruction_supported(qargs=(0,), parameters=[math.pi]))

def test_target_serialization_preserve_variadic(self):
"""Checks that variadics are still seen as variadic after serialization"""

target = Target("test", 2)
# Add variadic example gate with no properties.
target.add_instruction(UnitaryGate, None, "u_var")

# Check that this this instruction is compatible with qargs (0,). Should be
# true since variadic operation can be used with any valid qargs.
self.assertTrue(target.instruction_supported("u_var", (0, 1)))

# Rebuild the target using serialization
deserialized_target = loads(dumps(target))

# Perform check again, should not throw exception
self.assertTrue(deserialized_target.instruction_supported("u_var", (0, 1)))


class TestPulseTarget(QiskitTestCase):
def setUp(self):
Expand Down

0 comments on commit f2e07bc

Please sign in to comment.