Skip to content

Commit

Permalink
cleanup: oo is now ++, dim is now height, def of T
Browse files Browse the repository at this point in the history
  • Loading branch information
melsman committed Jan 16, 2025
1 parent dfe5ebf commit cac388f
Show file tree
Hide file tree
Showing 14 changed files with 232 additions and 75 deletions.
13 changes: 7 additions & 6 deletions src/circuit.sig
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
signature CIRCUIT = sig

datatype t = I | X | Y | Z | H | SW
datatype t = I | X | Y | Z | H | T | SW
| C of t
| Tensor of t * t
| Seq of t * t

val oo : t * t -> t
val ** : t * t -> t
val ++ : t * t -> t
val ** : t * t -> t

val pp : t -> string
val draw : t -> string
val dim : t -> int
val pp : t -> string
val draw : t -> string
val draw_latex : t -> string
val height : t -> int

end
42 changes: 33 additions & 9 deletions src/circuit.sml
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,22 @@ structure Circuit : CIRCUIT = struct

infix |> fun a |> f = f a

datatype t = I | X | Y | Z | H | SW
datatype t = I | X | Y | Z | H | T | SW
| Tensor of t * t
| Seq of t * t
| C of t

val oo = op Seq
val ++ = op Seq
val ** = op Tensor

fun pp t =
let fun maybePar P s = if P then "(" ^ s ^ ")" else s
fun pp p t =
case t of
Tensor(t1,t2) => maybePar (p > 4) (pp 4 t1 ^ " ** " ^ pp 4 t2)
| Seq(t1,t2) => maybePar (p > 3) (pp 3 t1 ^ " oo " ^ pp 3 t2)
| Seq(t1,t2) => maybePar (p > 3) (pp 3 t1 ^ " ++ " ^ pp 3 t2)
| C t => "C" ^ pp 8 t
| I => "I" | X => "X" | Y => "Y" | Z => "Z" | H => "H" | SW => "SW"
| I => "I" | X => "X" | Y => "Y" | Z => "Z" | H => "H" | T => "T" | SW => "SW"
in pp 0 t
end

Expand All @@ -32,6 +32,7 @@ structure Circuit : CIRCUIT = struct
| Y => Diagram.box "Y"
| Z => Diagram.box "Z"
| H => Diagram.box "H"
| T => Diagram.box "T"
| C X => Diagram.cntrl "X"
| C Y => Diagram.cntrl "Y"
| C Z => Diagram.cntrl "Z"
Expand All @@ -41,17 +42,40 @@ structure Circuit : CIRCUIT = struct
in dr t |> Diagram.toString
end

fun dim t =
structure DiagramL = DiagramLatex

fun draw_latex t =
let fun dr t =
case t of
SW => DiagramL.swap
| Tensor(a,b) => DiagramL.par(dr a, dr b)
| Seq(a,b) => DiagramL.seq(dr a, dr b)
| I => DiagramL.line
| X => DiagramL.box "X"
| Y => DiagramL.box "Y"
| Z => DiagramL.box "Z"
| H => DiagramL.box "H"
| T => DiagramL.box "T"
| C X => DiagramL.cntrl "X"
| C Y => DiagramL.cntrl "Y"
| C Z => DiagramL.cntrl "Z"
| C H => DiagramL.cntrl "H"
| C _ => raise Fail ("Circuit.draw_latex: Controlled circuit " ^
pp t ^ " cannot be drawn")
in dr t |> DiagramL.toString
end

fun height t =
case t of
Tensor(a,b) => dim a + dim b
Tensor(a,b) => height a + height b
| Seq(a,b) =>
let val d = dim a
in if d <> dim b
let val d = height a
in if d <> height b
then raise Fail "Sequence error: mismatching dimensions"
else d
end
| SW => 2
| C t => 1 + dim t
| C t => 1 + height t
| _ => 1

end
36 changes: 21 additions & 15 deletions src/comp.sml
Original file line number Diff line number Diff line change
Expand Up @@ -35,29 +35,30 @@ structure Comp :> COMP = struct
| Y => ret (APP("Y",[]))
| Z => ret (APP("Z",[]))
| H => ret (APP("H",[]))
| T => ret (APP("T",[]))
| SW => ret (APP("SW",[]))
| Seq(t1,t2) =>
comp t1 >>= (fn e1 =>
comp t2 >>= (fn e2 =>
let val n = Int.toString (pow2 (dim t))
let val n = Int.toString (pow2 (height t))
val ty = "[" ^ n ^ "][" ^ n ^ "]C.complex"
in ret (TYPED(APP("matmul", [e2,e1]),ty))
end))
| Tensor(t1,t2) =>
comp t1 >>= (fn e1 =>
comp t2 >>= (fn e2 =>
let val n = Int.toString (pow2 (dim t))
let val n = Int.toString (pow2 (height t))
val ty = "[" ^ n ^ "][" ^ n ^ "]C.complex"
in ret (TYPED(APP("tensor", [e1,e2]),ty))
end))
| C t' => comp t' >>= (fn e =>
let val n = Int.toString (pow2 (dim t))
let val n = Int.toString (pow2 (height t))
val ty = "[" ^ n ^ "][" ^ n ^ "]C.complex"
in ret (TYPED(APP("control",[e]),ty))
end)
end

fun vecTyFromDim d =
fun vecTyFromHeight d =
"[" ^ Int.toString(pow2 d) ^ "]C.complex"

local
Expand All @@ -73,15 +74,15 @@ structure Comp :> COMP = struct
end
fun FunC A (f:F.exp -> F.exp F.M) : F.var option F.M =
if allI A then ret NONE
else let val ty = vecTyFromDim (Circuit.dim A)
else let val ty = vecTyFromHeight (Circuit.height A)
in Fun f ty ty >>= (ret o SOME)
end
fun splitF d v =
let val ty = "[" ^ Int.toString (pow2 d) ^ "+" ^
Int.toString (pow2 d) ^ "]C.complex"
in APP("split",[TYPED(v,ty)])
end
fun concatF d a b = TYPED(APP("concat",[a,b]),vecTyFromDim d)
fun concatF d a b = TYPED(APP("concat",[a,b]),vecTyFromHeight d)
fun unvecF (e,ty) = APP("unvec", [TYPED(e,ty)])
fun vecF e = APP("vec",[e])
fun mapF f e = APP("map", [VAR f,e])
Expand All @@ -95,25 +96,26 @@ structure Comp :> COMP = struct
Circuit.I => ret v
| Circuit.Seq(t1,t2) => icomp t1 v >>= (icomp t2)
| Circuit.C t' =>
Let (splitF (Circuit.dim t') v) >>= (fn p =>
Let (splitF (Circuit.height t') v) >>= (fn p =>
icomp t' (SEL(1,VAR p)) >>= (fn v1 =>
ret (concatF (Circuit.dim t) (SEL(0,VAR p)) v1)))
ret (concatF (Circuit.height t) (SEL(0,VAR p)) v1)))
| Circuit.Tensor(A,B) =>
FunC A (icomp A) >>= (fn Af =>
FunC B (icomp B) >>= (fn Bf =>
let val dA = pow2(Circuit.dim A)
val dB = pow2(Circuit.dim B)
let val dA = pow2(Circuit.height A)
val dB = pow2(Circuit.height B)
val ty = "[" ^ Int.toString dA ^ "*" ^
Int.toString dB ^ "]C.complex"
in Let (unvecF(v,ty)) >>= (fn V =>
Let (mapF' Bf (transposeF (VAR V))) >>= (fn W =>
Let (mapF' Af (transposeF (VAR W))) >>= (fn Y =>
ret (TYPED(vecF (VAR Y),vecTyFromDim (Circuit.dim t))))))
ret (TYPED(vecF (VAR Y),vecTyFromHeight (Circuit.height t))))))
end))
| Circuit.H => ret (matvecmulF (APP("H",[])) v)
| Circuit.X => ret (matvecmulF (APP("X",[])) v)
| Circuit.Y => ret (matvecmulF (APP("Y",[])) v)
| Circuit.Z => ret (matvecmulF (APP("Z",[])) v)
| Circuit.H => ret (matvecmulF (APP("H",[])) v)
| Circuit.T => ret (matvecmulF (APP("T",[])) v)
| Circuit.SW => ret (matvecmulF (APP("SW",[])) v)
end

Expand All @@ -126,20 +128,24 @@ structure Comp :> COMP = struct
val cni = APP("C.mk_im", [CONST "(-1)"])
val rsqrt2 = APP("C.mk_re", [CONST "(1.0 / f64.sqrt(2.0))"])
val rnsqrt2 = APP("C.mk_re", [CONST "((-1.0) / f64.sqrt(2.0))"])
val tmp = APP("C.exp", [APP("C.mk_im",[CONST "(f64.pi/4)"])])
val rsqrt2eipi4 = APP("C.*", [rsqrt2,tmp])
fun ty n = "[" ^ Int.toString n ^ "][" ^ Int.toString n ^ "]C.complex"
fun binds nil = ret ()
| binds ((s,n,e)::rest) =
FunNamed s (fn _ => ret e) "()" (ty n) >>= (fn _ => binds rest)
in binds [("I", 2, ARR[ARR[c1,c0],
ARR[c0,c1]]),
("H", 2, ARR[ARR[rsqrt2,rsqrt2],
ARR[rsqrt2,rnsqrt2]]),
("X", 2, ARR[ARR[c0,c1],
ARR[c1,c0]]),
("Y", 2, ARR[ARR[c0,cni],
ARR[ci,c0]]),
("Z", 2, ARR[ARR[c1,c0],
ARR[c0,cn1]]),
("H", 2, ARR[ARR[rsqrt2,rsqrt2],
ARR[rsqrt2,rnsqrt2]]),
("T", 2, ARR[ARR[rsqrt2,c0],
ARR[c0,rsqrt2eipi4]]),
("SW", 4, ARR[ARR[c1,c0,c0,c0],
ARR[c0,c0,c1,c0],
ARR[c0,c1,c0,c0],
Expand All @@ -156,7 +162,7 @@ structure Comp :> COMP = struct

fun circuitToFutFunBind (f:string) (t:Circuit.t) : string =
let open F infix >>=
val ty = vecTyFromDim (Circuit.dim t)
val ty = vecTyFromHeight (Circuit.height t)
in runBinds (FunNamed f (icomp t) ty ty >>= (fn _ =>
ret()))
end
Expand Down
124 changes: 124 additions & 0 deletions src/diagram-latex.sml
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
structure DiagramLatex :> DIAGRAM = struct

datatype t = Box of string
| Line
| Cntrl of string
| Swap
| Par of t * t
| Seq of t * t

fun depth t : int =
case t of
Line => 1
| Box _ => 1
| Cntrl _ => 1
| Swap => 1
| Par (t1,t2) => Int.max(depth t1, depth t2)
| Seq(t1,t2) => depth t1 + depth t2

fun height t : int =
case t of
Line => 1
| Box _ => 1
| Cntrl _ => 2
| Swap => 2
| Par (t1,t2) => height t1 + height t2
| Seq(t1,t2) => Int.max(height t1, height t2)

val dy = 10
val dx = 10

fun i2s x = if x < 0 then "-" ^ i2s (~x)
else Int.toString x

fun put (x,y) c =
"\\put(" ^ i2s x ^ "," ^ i2s y ^ "){" ^ c ^ "}"

fun circ () = "\\circle*{" ^ i2s (dy div 10) ^ "}"

fun line (x,y) l =
"\\line(" ^ i2s x ^ "," ^ i2s y ^ "){" ^ i2s l ^ "}"

fun framebox (sx,sy) s =
"\\framebox(" ^ i2s sx ^ "," ^ i2s sy ^ "){" ^ s ^ "}"

fun put_line (x,y) a =
put (x,y + dy div 2) (line (1,0) dx) :: a

fun put_swap (x,y) a =
let val (x1,y1) = (x + dx div 2, y + dy div 2)
val (x2,y2) = (x1,y1-dy)
in put_line (x,y)
(put_line (x,y-dy)
(put (x1,y2) (line (0,1) dy) ::
put (x1,y1) (circ()) ::
put (x2,y2) (circ()) :: a))
end

fun put_cntrl (x,y) a =
let val (x1,y1) = (x + dx div 2, y + dy div 2)
val dy' = dy - 3 * dy div 10
in put_line (x,y)
(put (x1,y1) (line (0,~1) dy') ::
put (x1,y1) (circ()) :: a)
end

fun put_box s (x,y) a =
let val dx' = dx div 5
val x' = x + dx'
val dy' = dy div 5
val y' = y + dy'
val sx = dx - 2 * dx'
val sy = dy - 2 * dy'
in put (x',y') (framebox(sx,sy) s) ::
put (x,y + dy div 2) (line(1,0) dx') ::
put (x+dx,y + dy div 2) (line(~1,0) dx') :: a
end

fun lines n =
if n > 1 then Par(Line,lines(n-1))
else Line

fun padl t =
Seq(lines (height t), t)

fun padr t =
Seq(t,lines (height t))

fun toStr x y t a =
case t of
Box s => put_box s (x,y) a
| Line => put_line (x,y) a
| Swap => put_swap (x,y) a
| Cntrl s => put_box s (x,y - dy) (put_cntrl (x,y) a)
| Seq (t1,t2) => toStr (x + dx*(depth t1)) y t2 (toStr x y t1 a)
| Par (t1,t2) =>
let val d1 = depth t1
val d2 = depth t2
in if d1 > d2 + 1 then
toStr x y (Par(t1,padl (padr t2))) a
else if d1 > d2 then
toStr x y (Par(t1,padl t2)) a
else if d2 > d1 + 1 then
toStr x y (Par(padl (padr t1),t2)) a
else if d2 > d1 then
toStr x y (Par(padl t1,t2)) a
else
toStr x (y - dy*(height t1)) t2 (toStr x y t1 a)
end

fun toString t =
let val (h,d) = (height t, depth t)
in String.concatWith "\n"
("\\begin{picture}(" ^ i2s (dx*d) ^ "," ^ i2s (dy*h) ^ ")(0,0)" ::
toStr 0 ((h-1)*dy) t ["\\end{picture}"])
end

val box = Box
val line = Line
val cntrl = Cntrl
val swap = Swap
val seq = Seq
val par = Par

end
4 changes: 4 additions & 0 deletions src/diagram.mlb
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
local $(SML_LIB)/basis/basis.mlb
in diagram.sml
diagram-latex.sml
end
2 changes: 1 addition & 1 deletion src/diagram.sml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ sig
val toString : t -> string
end

structure Diagram : DIAGRAM =
structure Diagram :> DIAGRAM =
struct
type t = string list (* lines; invariant: lines have equal size *)

Expand Down
2 changes: 1 addition & 1 deletion src/quantum.mlb
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
local $(SML_LIB)/basis/basis.mlb
../lib/github.com/diku-dk/sml-matrix/matrix.mlb
../lib/github.com/diku-dk/sml-complex/complex.mlb
in diagram.sml
in diagram.mlb
circuit.sig
circuit.sml
semantics.sig
Expand Down
5 changes: 3 additions & 2 deletions src/quantum_ex1.sml
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@

open Circuit Semantics
infix 3 oo
infix 3 ++
infix 4 **

fun run c k =
(print ("Circuit for c = " ^ pp c ^ ":\n");
print (draw c ^ "\n");
print (draw_latex c ^ "\n");
print ("Semantics of c:\n" ^ pp_mat(sem c) ^ "\n");
print ("Result distribution when evaluating c on " ^ pp_ket k ^ " :\n");
let val v0 = init k
Expand All @@ -15,4 +16,4 @@ fun run c k =
; print ("V2: " ^ pp_state (interp c v0) ^ "\n")
end)

val () = run ((I ** H oo C X oo Z ** Z oo C X oo I ** H) ** I oo I ** SW oo C X ** Y) (ket[1,0,1])
val () = run ((I ** H ++ C X ++ Z ** Z ++ C X ++ I ** H) ** I ++ I ** SW ++ C X ** Y) (ket[1,0,1])
Loading

0 comments on commit cac388f

Please sign in to comment.