Skip to content

Commit

Permalink
Today's progress (to be rebased)
Browse files Browse the repository at this point in the history
  • Loading branch information
raynelfss committed Jan 16, 2025
1 parent 3a0c5bb commit 4dbfe36
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 154 deletions.
108 changes: 81 additions & 27 deletions crates/circuit/src/bit_data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,10 @@ use crate::imports::{CLASSICAL_REGISTER, QUANTUM_REGISTER, REGISTER};
use crate::register::{Register, RegisterAsKey};
use crate::{BitType, ToPyBit};
use hashbrown::HashMap;
use indexmap::{Equivalent, IndexSet};
use indexmap::IndexSet;
use pyo3::exceptions::{PyKeyError, PyRuntimeError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList};
use std::borrow::Borrow;
use std::fmt::Debug;
use std::hash::{Hash, Hasher};
use std::sync::OnceLock;
Expand Down Expand Up @@ -245,25 +244,29 @@ pub struct NewBitData<T: From<BitType>, R: Register + Hash + Eq> {
bits: Vec<OnceLock<PyObject>>,
/// Maps Python bits to native type.
indices: HashMap<BitAsKey, T>,
/// Maps Register keys to indices
reg_keys: HashMap<RegisterAsKey, u32>,
/// Mapping between bit index and its register info
bit_info: Vec<Option<BitInfo>>,
/// Registers in the circuit
registry: IndexSet<R>,
registry: Vec<R>,
/// Registers in Python
registers: Vec<OnceLock<PyObject>>,
/// Cached Python bits
cached_py_bits: OnceLock<Py<PyList>>,
/// Cached Python registers
cached_py_regs: OnceLock<Py<PyList>>,
}

impl<T, R> NewBitData<T, R>
where
T: From<BitType> + Copy + Debug + ToPyBit,
R: Register<Bit = T>
+ Equivalent<RegisterAsKey>
+ for<'a> Borrow<&'a RegisterAsKey>
+ Hash
+ Eq
+ From<(usize, Option<String>)>
+ for<'a> From<&'a [T]>
+ for<'a> From<(&'a [T], Option<String>)>,
+ for<'a> From<(&'a [T], String)>,
BitType: From<T>,
{
pub fn new(description: String) -> Self {
Expand All @@ -272,19 +275,25 @@ where
bits: Vec::new(),
indices: HashMap::new(),
bit_info: Vec::new(),
registry: IndexSet::new(),
registry: Vec::new(),
registers: Vec::new(),
cached_py_bits: OnceLock::new(),
cached_py_regs: OnceLock::new(),
reg_keys: HashMap::new(),
}
}

pub fn with_capacity(description: String, capacity: usize) -> Self {
pub fn with_capacity(description: String, bit_capacity: usize, reg_capacity: usize) -> Self {
NewBitData {
description,
bits: Vec::with_capacity(capacity),
indices: HashMap::with_capacity(capacity),
bit_info: Vec::with_capacity(capacity),
registry: IndexSet::with_capacity(capacity),
registers: Vec::with_capacity(capacity),
bits: Vec::with_capacity(bit_capacity),
indices: HashMap::with_capacity(bit_capacity),
bit_info: Vec::with_capacity(bit_capacity),
registry: Vec::with_capacity(reg_capacity),
registers: Vec::with_capacity(reg_capacity),
cached_py_bits: OnceLock::new(),
cached_py_regs: OnceLock::new(),
reg_keys: HashMap::with_capacity(reg_capacity)
}
}

Expand All @@ -302,12 +311,6 @@ where
self.bits.is_empty()
}

/// Gets a reference to the underlying vector of Python bits.
#[inline]
pub fn bits(&self) -> &Vec<OnceLock<PyObject>> {
&self.bits
}

/// Adds a register onto the [BitData] of the circuit.
pub fn add_register(
&mut self,
Expand All @@ -318,7 +321,7 @@ where
match (size, bits) {
(None, None) => panic!("You should at least provide either a size or the bit indices."),
(None, Some(bits)) => {
let reg: R = (bits, name).into();
let reg: R = if let Some(name) = name {(bits, name).into()} else {bits.into()};
let idx = self.registry.len().try_into().unwrap_or_else(|_| {
panic!(
"The {} registry in this circuit has reached its maximum capacity.",
Expand All @@ -345,20 +348,22 @@ where
))
}
}
self.registry.insert(reg);
self.reg_keys.insert(reg.as_key().clone(), idx);
self.registry.push(reg);
self.registers.push(OnceLock::new());
idx
}
(Some(size), None) => {
let bits: Vec<T> = (0..size).map(|_| self.add_bit()).collect();
let reg = (bits.as_slice(), name).into();
let reg: R = if let Some(name) = name {(bits.as_slice(), name).into()} else {bits.as_slice().into()};
let idx = self.registry.len().try_into().unwrap_or_else(|_| {
panic!(
"The {} registry in this circuit has reached its maximum capacity.",
self.description
)
});
self.registry.insert(reg);
self.reg_keys.insert(reg.as_key().clone(), idx);
self.registry.push(reg);
self.registers.push(OnceLock::new());
idx
}
Expand All @@ -384,12 +389,14 @@ where
}

/// Retrieves a register by its index within the circuit
#[inline]
pub fn get_register(&self, index: u32) -> Option<&R> {
self.registry.get_index(index as usize)
self.registry.get(index as usize)
}

#[inline]
pub fn get_register_by_key(&self, key: &RegisterAsKey) -> Option<&R> {
self.registry.get(&key)
self.reg_keys.get(key).and_then(|idx| self.get_register(*idx))
}

// =======================
Expand All @@ -401,6 +408,24 @@ where
pub fn py_find_bit(&self, bit: &Bound<PyAny>) -> Option<T> {
self.indices.get(&BitAsKey::new(bit)).copied()
}

/// Gets a reference to the cached Python list, maintained by
/// this instance.
#[inline]
pub fn py_cached_bits(&self, py: Python) -> &Py<PyList> {
&self.cached_py_bits.get_or_init(|| PyList::empty_bound(py).into())
}

/// Gets a reference to the underlying vector of Python bits.
// #[inline]
// pub fn py_bits<'a> (&'a self, py: Python<'a>) -> impl ExactSizeIterator<Item = &PyObject> + 'a {
// self.bits.iter().enumerate().map(move |(idx, bit)| {
// if bit.get().is_none() {
// self.py_get_bit(py, T::from(idx.try_into().unwrap()));
// }
// bit.get().unwrap()
// })
// }

/// Map the provided Python bits to their native indices.
/// An error is returned if any bit is not registered.
Expand All @@ -425,10 +450,20 @@ where
v.map(|x| x.into_iter())
}

/// Map the provided native indices to the corresponding Python
/// bit instances.
/// Panics if any of the indices are out of range.
pub fn py_map_indices(&mut self, py: Python, bits: &[T]) -> PyResult<impl ExactSizeIterator<Item = &Py<PyAny>>> {
let v: Vec<_> = bits.iter().map(|i| -> PyResult<&PyObject> {
Ok(self.py_get_bit(py, *i)?.unwrap())
}).collect::<PyResult<_>>()?;
Ok(v.into_iter())
}

/// Gets the Python bit corresponding to the given native
/// bit index.
#[inline]
pub fn py_get_bit(&mut self, py: Python, index: T) -> PyResult<Option<&PyObject>> {
pub fn py_get_bit(&self, py: Python, index: T) -> PyResult<Option<&PyObject>> {
/*
For this method we want to make sure a couple of things are done first:
Expand Down Expand Up @@ -472,7 +507,7 @@ where
}

/// Retrieves a register instance from Python based on the rust description.
pub fn py_get_register(&mut self, py: Python, index: u32) -> PyResult<Option<&PyObject>> {
pub fn py_get_register(&self, py: Python, index: u32) -> PyResult<Option<&PyObject>> {
let index_as_usize = index as usize;
// First check if the cell is in range if not, return none
if self.registers.get(index_as_usize).is_none() {
Expand Down Expand Up @@ -533,6 +568,12 @@ where
pub fn py_add_bit(&mut self, bit: &Bound<PyAny>, strict: bool) -> PyResult<T> {
let py: Python<'_> = bit.py();

if self.bits.len() != self.cached_py_bits.get_or_init(|| PyList::empty_bound(py).into()).bind(bit.py()).len() {
return Err(PyRuntimeError::new_err(
format!("This circuit's {} list has become out of sync with the circuit data. Did something modify it?", self.description)
));
}

let idx: BitType = self.bits.len().try_into().map_err(|_| {
PyRuntimeError::new_err(format!(
"The number of {} in the circuit has exceeded the maximum capacity",
Expand All @@ -546,6 +587,7 @@ where
{
self.bit_info.push(None);
self.bits.push(bit.into_py(py).into());
self.cached_py_bits.get_or_init(|| PyList::empty_bound(py).into()).bind(py).append(bit)?;
// self.cached.bind(py).append(bit)?;
} else if strict {
return Err(PyValueError::new_err(format!(
Expand All @@ -557,6 +599,13 @@ where
}

pub fn py_add_register(&mut self, register: &Bound<PyAny>) -> PyResult<u32> {
let py = register.py();
if self.registers.len() != self.cached_py_regs.get_or_init(|| PyList::empty_bound(py).into()).bind(py).len() {
return Err(PyRuntimeError::new_err(
format!("This circuit's {} list has become out of sync with the circuit data. Did something modify it?", self.description)
));
}

// let index: u32 = self.registers.len().try_into().map_err(|_| {
// PyRuntimeError::new_err(format!(
// "The number of {} registers in the circuit has exceeded the maximum capacity",
Expand All @@ -578,6 +627,7 @@ where

let name: String = register.getattr("name")?.extract()?;
self.registers.push(register.clone().unbind().into());
self.cached_py_regs.get_or_init(|| PyList::empty_bound(py).into()).bind(py).append(&register)?;
Ok(self.add_register(Some(name), None, Some(&bits)))
}

Expand Down Expand Up @@ -613,6 +663,10 @@ where
Ok(())
}

pub fn py_bits(&self, py: Python) -> PyResult<Vec<&PyObject>> {
(0..self.len()).map(|idx| self.py_get_bit(py, (idx as u32).into()).map( |bit| bit.unwrap())).collect::<PyResult<_>>()
}

/// Called during Python garbage collection, only!.
/// Note: INVALIDATES THIS INSTANCE.
pub fn dispose(&mut self) {
Expand Down
Loading

0 comments on commit 4dbfe36

Please sign in to comment.