Skip to content

Commit

Permalink
shift and add multiplication (#41)
Browse files Browse the repository at this point in the history
* shift and add multiplication
* test suite for mult_by_const
  • Loading branch information
dakk authored Apr 17, 2024
1 parent 74c61ab commit 3498a6c
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 37 deletions.
22 changes: 11 additions & 11 deletions qlasskit/ast2ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@ def _replace_types_annotations(ann, arg=None):
isinstance(ann, ast.Subscript)
and isinstance(ann.value, ast.Name)
and ann.value.id == "Tuple"
and hasattr(ann.slice, 'elts')
and hasattr(ann.slice, "elts")
):
_elts = ann.slice.elts
_elts = ann.slice.elts
_ituple = ast.Tuple(elts=[_replace_types_annotations(el) for el in _elts])

ann = ast.Subscript(
Expand All @@ -101,7 +101,7 @@ def _replace_types_annotations(ann, arg=None):
isinstance(ann, ast.Subscript)
and isinstance(ann.value, ast.Name)
and ann.value.id == "Qlist"
and hasattr(ann.slice, 'elts')
and hasattr(ann.slice, "elts")
):
_elts = ann.slice.elts
_ituple = ast.Tuple(elts=[copy.deepcopy(_elts[0])] * _elts[1].value)
Expand All @@ -116,7 +116,7 @@ def _replace_types_annotations(ann, arg=None):
isinstance(ann, ast.Subscript)
and isinstance(ann.value, ast.Name)
and ann.value.id == "Qmatrix"
and hasattr(ann.slice, 'elts')
and hasattr(ann.slice, "elts")
):
_elts = ann.slice.elts
_ituple_row = ast.Tuple(elts=[copy.deepcopy(_elts[0])] * _elts[2].value)
Expand Down Expand Up @@ -425,31 +425,31 @@ def visit_For(self, node): # noqa: C901
isinstance(iter, ast.Subscript)
and isinstance(iter.value, ast.Name)
and iter.value.id in self.env
and hasattr(iter.slice, 'value')
and hasattr(iter.slice, "value")
):
if isinstance(self.env[iter.value.id], ast.Tuple):
new_iter = self.env[iter.value.id].elts[iter.slice.value]

elif isinstance(self.env[iter.value.id], ast.Subscript):
_elts = self.env[iter.value.id].slice.elts[iter.slice.value]
elif isinstance(self.env[iter.value.id], ast.Subscript):
_elts = self.env[iter.value.id].slice.elts[iter.slice.value]

if isinstance(_elts, ast.Tuple):
_elts = _elts.elts

new_iter = [
ast.Subscript(
value=ast.Subscript(
value=ast.Name(id=iter.value.id, ctx=ast.Load()),
slice=ast.Constant(value=iter.slice.value),
value=ast.Name(id=iter.value.id, ctx=ast.Load()),
slice=ast.Constant(value=iter.slice.value),
ctx=ast.Load(),
),
slice=ast.Constant(value=e),
)
for e in range(len(_elts))
]
else:
new_iter = iter
new_iter = iter

iter = new_iter

if isinstance(iter, ast.Constant) and isinstance(iter.value, ast.Tuple):
Expand Down
98 changes: 79 additions & 19 deletions qlasskit/types/qint.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from sympy.logic import And, Not, Or, Xor, false, true

from . import TypeErrorException, _eq, _full_adder, _neq
from .qtype import Qtype, TExp, bin_to_bool_list, bool_list_to_bin
from .qtype import Qtype, TExp, TType, bin_to_bool_list, bool_list_to_bin


class QintImp(int, Qtype):
Expand Down Expand Up @@ -164,12 +164,86 @@ def add(cls, tleft: TExp, tright: TExp) -> TExp:

return (cls if cls.BIT_SIZE > tleft_e[0].BIT_SIZE else tleft_e[0], sums)

@staticmethod
def mul_even_const(t_num: TExp, const: int, result_type: Qtype) -> TExp:
"""Multiply by an even const using shift and add
(x << 3) + (x << 1) # Here 10*x is computed as x*2^3 + x*2
"""

# Multiply t_num by the nearest n | 2**n < t_const
n = 1
while 2**n <= const:
n += 1
if 2**n > const:
n -= 1

result_ttype = cast(TType, result_type)

t_num_r = result_type.shift_left((result_ttype, t_num[1]), n)

# Shift t_const by t_const - 2**n
r = const - 2**n
if r > 0:
# Add the shift result to t_num
res = result_type.add(
(result_ttype, t_num_r[1]),
result_type.shift_left((result_ttype, t_num[1]), int(r / 2)),
)
else:
res = (result_ttype, t_num_r[1])

return res

@classmethod
def mul(cls, tleft: TExp, tright: TExp) -> TExp: # noqa: C901
# TODO: use RGQFT multiplier
def mul(cls, tleft_: TExp, tright_: TExp) -> TExp: # noqa: C901
if not issubclass(tleft_[0], Qtype):
raise TypeErrorException(tleft_[0], Qtype)
if not issubclass(tright_[0], Qtype):
raise TypeErrorException(tright_[0], Qtype)

def __mul_sizing(n, m):
if (n + m) <= 2:
return Qint2
elif (n + m) > 2 and (n + m) <= 4:
return Qint4
elif (n + m) > 4 and (n + m) <= 6:
return Qint6
elif (n + m) > 6 and (n + m) <= 8:
return Qint8
elif (n + m) > 8 and (n + m) <= 12:
return Qint12
elif (n + m) > 12 and (n + m) <= 16:
return Qint16
elif (n + m) > 16:
return Qint16
else:
raise Exception(f"Mul result size is too big ({n+m})")

# Fill constants so explicit typecast is not needed
if cls.is_const(tleft_):
tleft = tright_[0].fill(tleft_)
else:
tleft = tleft_

if cls.is_const(tright_):
tright = tleft_[0].fill(tright_)
else:
tright = tright_

n = len(tleft[1])
m = len(tright[1])

# If one operand is an even constant, use mul_even_const
if cls.is_const(tleft) or cls.is_const(tright):
t_num = tleft if cls.is_const(tright) else tright
t_const = tleft if cls.is_const(tleft) else tright
const = cast(int, cast(Qtype, t_const[0]).from_bool(t_const[1]))

if const % 2 == 0:
t = __mul_sizing(n, m)
res = cls.mul_even_const(t_num, const, t)
return t.crop(t.fill(res))

if n != m:
raise Exception(f"Mul works only on same size Qint: {n} != {m}")

Expand All @@ -190,22 +264,8 @@ def mul(cls, tleft: TExp, tright: TExp) -> TExp: # noqa: C901
if i + m < n + m:
product[i + m] = carry

if (n + m) <= 2:
return Qint2, product
elif (n + m) > 2 and (n + m) <= 4:
return Qint4, product
elif (n + m) > 4 and (n + m) <= 6:
return Qint6, product
elif (n + m) > 6 and (n + m) <= 8:
return Qint8, product
elif (n + m) > 8 and (n + m) <= 12:
return Qint12, product
elif (n + m) > 12 and (n + m) <= 16:
return Qint16, product
elif (n + m) > 16:
return Qint16.crop((Qint16, product))

raise Exception(f"Mul result size is too big ({n+m})")
t = __mul_sizing(n, m)
return t.crop(t.fill((t, product)))

@classmethod
def sub(cls, tleft: TExp, tright: TExp) -> TExp:
Expand Down
10 changes: 8 additions & 2 deletions qlasskit/types/qtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,12 +187,18 @@ def bitwise_not(v: TExp) -> TExp:
@staticmethod
def shift_right(v: TExp, i: int = 1) -> TExp:
"""Apply a shift right"""
return (v[0], v[1][i:])
if not issubclass(v[0], Qtype):
raise TypeErrorException(v[0], Qtype)

return v[0].fill((v[0], v[1][i:]))

@staticmethod
def shift_left(v: TExp, i: int = 1) -> TExp:
"""Apply a shift left"""
return (v[0], [False] * i + v[1])
if not issubclass(v[0], Qtype):
raise TypeErrorException(v[0], Qtype)

return v[0].crop((v[0], [False] * i + v[1]))

@staticmethod
def add(tleft: TExp, tright: TExp) -> TExp:
Expand Down
42 changes: 38 additions & 4 deletions test/qlassf/test_int.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,13 +293,27 @@ def test_composed_comparators(self):
qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler)
compute_and_compare_results(self, qf)

def test_shift_left(self):
f = "def test(n: Qint[2]) -> Qint[4]: return n << 1"
@parameterized.expand(
[
(1,),
(2,),
(3,),
]
)
def test_shift_left(self, v):
f = f"def test(n: Qint[4]) -> Qint[4]: return n << {v}"
qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler)
compute_and_compare_results(self, qf)

def test_shift_right(self):
f = "def test(n: Qint[2]) -> Qint[4]: return n >> 1"
@parameterized.expand(
[
(1,),
(2,),
(3,),
]
)
def test_shift_right(self, v):
f = f"def test(n: Qint[2]) -> Qint[4]: return n >> {v}"
qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler)
compute_and_compare_results(self, qf)

Expand Down Expand Up @@ -479,3 +493,23 @@ def test_mul5(self):
qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler)
compute_and_compare_results(self, qf)
self.assertEqual(qf.expressions[0][1], True)


@parameterized_class(
("ttype_i", "ttype_o", "const", "compiler"),
inject_parameterized_compilers(
[
(4, 6, 2),
(4, 6, 4),
(4, 6, 6),
(6, 8, 6),
(6, 8, 8),
(6, 8, 10),
]
),
)
class TestQlassfIntMulByConst(unittest.TestCase):
def test_mul(self):
f = f"def test(a: Qint[{self.ttype_i}]) -> Qint[{self.ttype_o}]: return a * {self.const}"
qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler)
compute_and_compare_results(self, qf)
2 changes: 1 addition & 1 deletion test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def compute_and_compare_results(cls, qf, test_original_f=True, test_qcircuit=Tru

# circ_qi = qf.circuit().export("circuit", "qiskit")

# update_statistics(qf.circuit().num_qubits, qf.circuit().num_gates)
update_statistics(qf.circuit().num_qubits, qf.circuit().num_gates)

# print(qf.expressions)
# print(circ_qi.draw("text"))
Expand Down

0 comments on commit 3498a6c

Please sign in to comment.