Skip to content

Commit

Permalink
fix typing in ast2ast
Browse files Browse the repository at this point in the history
  • Loading branch information
dakk committed Apr 11, 2024
1 parent a685084 commit 74c61ab
Showing 1 changed file with 44 additions and 20 deletions.
64 changes: 44 additions & 20 deletions qlasskit/ast2ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,13 @@ def visit_Name(self, node):

def _replace_types_annotations(ann, arg=None):
"""Replaces type annotations, translating high level types"""
if isinstance(ann, ast.Subscript) and ann.value.id == "Tuple": # type: ignore
_elts = ann.slice.elts # type: ignore
if (
isinstance(ann, ast.Subscript)
and isinstance(ann.value, ast.Name)
and ann.value.id == "Tuple"
and hasattr(ann.slice, 'elts')
):
_elts = ann.slice.elts
_ituple = ast.Tuple(elts=[_replace_types_annotations(el) for el in _elts])

ann = ast.Subscript(
Expand All @@ -92,8 +97,13 @@ def _replace_types_annotations(ann, arg=None):
)

# Replace Qlist[T,n] with Tuple[(T,)*n]
if isinstance(ann, ast.Subscript) and ann.value.id == "Qlist": # type: ignore
_elts = ann.slice.elts # type: ignore
if (
isinstance(ann, ast.Subscript)
and isinstance(ann.value, ast.Name)
and ann.value.id == "Qlist"
and hasattr(ann.slice, 'elts')
):
_elts = ann.slice.elts
_ituple = ast.Tuple(elts=[copy.deepcopy(_elts[0])] * _elts[1].value)

ann = ast.Subscript(
Expand All @@ -102,8 +112,13 @@ def _replace_types_annotations(ann, arg=None):
)

# Replace Qmatrix[T,n,m] with Tuple[(Tuple[(T,)*m],)*n]
if isinstance(ann, ast.Subscript) and ann.value.id == "Qmatrix": # type: ignore
_elts = ann.slice.elts # type: ignore
if (
isinstance(ann, ast.Subscript)
and isinstance(ann.value, ast.Name)
and ann.value.id == "Qmatrix"
and hasattr(ann.slice, 'elts')
):
_elts = ann.slice.elts
_ituple_row = ast.Tuple(elts=[copy.deepcopy(_elts[0])] * _elts[2].value)
_ituple = ast.Tuple(elts=[copy.deepcopy(_ituple_row)] * _elts[1].value)

Expand Down Expand Up @@ -232,10 +247,10 @@ def visit_If(self, node):
if not isinstance(b, ast.Assign):
raise Exception("if body only allows assigns: ", ast.dump(b))

if len(b.targets) != 1:
raise Exception("if targets only allow one: ", ast.dump(b))
if len(b.targets) != 1 or not isinstance(b.targets[0], ast.Name):
raise Exception("if targets only allow one Name target: ", ast.dump(b))

target_0id = b.targets[0].id # type: ignore
target_0id = b.targets[0].id

if target_0id[0:2] == "__" and target_0id not in self.env:
orelse_inner = ast.Name(id=target_0id[2:])
Expand All @@ -255,10 +270,10 @@ def visit_If(self, node):
if not isinstance(b, ast.Assign):
raise Exception("if body only allows assigns: ", ast.dump(b))

if len(b.targets) != 1:
raise Exception("if targets only allow one: ", ast.dump(b))
if len(b.targets) != 1 or not isinstance(b.targets[0], ast.Name):
raise Exception("if targets only allow one Name target: ", ast.dump(b))

target_0id = b.targets[0].id # type: ignore
target_0id = b.targets[0].id

if target_0id[0:2] == "__" and target_0id not in self.env:
orelse_inner = ast.Name(id=target_0id[2:])
Expand Down Expand Up @@ -406,27 +421,36 @@ def visit_For(self, node): # noqa: C901
]
elif isinstance(iter, ast.Tuple):
iter = iter.elts
elif isinstance(iter, ast.Subscript) and iter.value.id in self.env: # type: ignore
if isinstance(self.env[iter.value.id], ast.Tuple): # type: ignore
iter = self.env[iter.value.id].elts[iter.slice.value] # type: ignore
elif (
isinstance(iter, ast.Subscript)
and isinstance(iter.value, ast.Name)
and iter.value.id in self.env
and hasattr(iter.slice, 'value')
):
if isinstance(self.env[iter.value.id], ast.Tuple):
new_iter = self.env[iter.value.id].elts[iter.slice.value]

elif isinstance(self.env[iter.value.id], ast.Subscript): # type: ignore
_elts = self.env[iter.value.id].slice.elts[iter.slice.value] # type: ignore
elif isinstance(self.env[iter.value.id], ast.Subscript):
_elts = self.env[iter.value.id].slice.elts[iter.slice.value]

if isinstance(_elts, ast.Tuple):
_elts = _elts.elts

iter = [
new_iter = [
ast.Subscript(
value=ast.Subscript(
value=ast.Name(id=iter.value.id, ctx=ast.Load()), # type: ignore
slice=ast.Constant(value=iter.slice.value), # type: ignore
value=ast.Name(id=iter.value.id, ctx=ast.Load()),
slice=ast.Constant(value=iter.slice.value),
ctx=ast.Load(),
),
slice=ast.Constant(value=e),
)
for e in range(len(_elts))
]
else:
new_iter = iter

iter = new_iter

if isinstance(iter, ast.Constant) and isinstance(iter.value, ast.Tuple):
iter = iter.value.elts
Expand Down

0 comments on commit 74c61ab

Please sign in to comment.