Skip to content

Commit

Permalink
track nullability (generic)
Browse files Browse the repository at this point in the history
ref #5 #76
  • Loading branch information
ygrek committed Feb 23, 2023
1 parent 044a667 commit 5002eb8
Show file tree
Hide file tree
Showing 17 changed files with 288 additions and 163 deletions.
1 change: 1 addition & 0 deletions TODO
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
* WHERE ... IS NOT NULL should influence the inferred type (nullability)
* native-type annotations in queries
* allow parametrizing the SQL syntax itself (e.g. ORDER BY ASC|DESC): unsafe/enumeration/option params
* ocaml ppx syntax extension for inline sql
Expand Down
160 changes: 102 additions & 58 deletions lib/sql.ml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ open Prelude

module Type =
struct
type t =
type kind =
| Unit of [`Interval]
| Int
| Text
Expand All @@ -16,19 +16,37 @@ struct
| Datetime
| Decimal
| Any
[@@deriving show {with_path=false}]
[@@deriving eq, show{with_path=false}]

let to_string = show
(** Nullability tag attached to a SQL type: says whether a value may be NULL.
    Equality and printers are derived via ppx ([eq], [show]). *)
type nullability =
| Nullable (** can be NULL *)
| Strict (** cannot be NULL *)
| Depends (** unknown, to be determined *)
[@@deriving eq, show{with_path=false}]

let matches x y =
match x,y with
| Any, _ | _, Any -> true
| _ -> x = y
(** A SQL type: the underlying [kind] (field [t]) paired with its [nullability].
    [equal], [pp] and [show] are derived via ppx. *)
type t = { t : kind; nullability : nullability; }[@@deriving eq, show{with_path=false}]

(** [nullability n kind] builds the type whose kind is [kind] and whose
    nullability is [n]. *)
let nullability n kind = { t = kind; nullability = n }

(** [strict kind] is [kind] tagged as [Strict]. *)
let strict kind = nullability Strict kind

(** [depends kind] is [kind] tagged as [Depends]. *)
let depends kind = nullability Depends kind

(** [nullable kind] is [kind] tagged as [Nullable]. *)
let nullable kind = nullability Nullable kind

(* From here on [=] is the derived structural equality on [t] (kind and
   nullability), shadowing the polymorphic [=] for the code that follows. *)
let (=) : t -> t -> bool = equal

(** [show ty] prints the kind followed by a nullability suffix:
    ["?"] for [Nullable], ["??"] for [Depends], nothing for [Strict].
    Replaces the derived [show]. *)
let show { t; nullability; } = show_kind t ^ (match nullability with Nullable -> "?" | Depends -> "??" | Strict -> "")
(* Reference the derived [pp] before it is shadowed below
   (presumably to silence an unused-value warning; confirm). *)
let _ = pp
(** [pp pf ty] prints [show ty] on formatter [pf]. Replaces the derived [pp]. *)
let pp pf t = Format.pp_print_string pf (show t)

let is_unit = function Unit _ -> true | _ -> false
(** [type_name ty] is the printed name of the kind of [ty]; nullability is ignored. *)
let type_name t = show_kind t.t

let order x y =
if x = y then
(** [is_any ty] holds when the kind of [ty] is [Any], whatever its nullability. *)
let is_any ty = equal_kind ty.t Any

(** [matches x y] holds when either type is [Any] or when [x = y]. *)
let matches x y =
  if is_any x then true
  else if is_any y then true
  else x = y

(** [is_unit ty] holds when the kind of [ty] is a [Unit _]. *)
let is_unit ty =
  match ty.t with
  | Unit _ -> true
  | _ -> false

let order_kind x y =
if equal_kind x y then
`Equal
else
match x,y with
Expand All @@ -41,37 +59,59 @@ struct
| Text, Datetime | Datetime, Text -> `Order (Datetime,Text)
| _ -> `No

let common_type f x y =
match order x y with
| `Equal -> Some x
| `Order p -> Some (f p)
| `No -> None

let common_supertype = common_type snd
let common_subtype = common_type fst
let common_type x y = Option.is_some @@ common_subtype x y
(** [order_nullability x y] relates two nullability tags.
    - [`Equal n] when both sides agree, or when at least one side is [Depends]
      (then [n] is the other side);
    - [`Strict_Nullable] / [`Nullable_Strict] when one side is [Strict] and the
      other [Nullable], naming the sides in argument order. *)
let order_nullability x y =
  match x, y with
  (* genuine disagreement: report which side is which *)
  | Strict, Nullable -> `Strict_Nullable
  | Nullable, Strict -> `Nullable_Strict
  (* an undetermined side defers to the other one; open question from the
     original author: should this be an ordering rather than equality? *)
  | Depends, known
  | known, Depends -> `Equal known
  (* both sides determined and identical *)
  | Nullable, Nullable -> `Equal Nullable
  | Strict, Strict -> `Equal Strict

(** [common_nullability l] is the combined nullability of the types in [l]:
    [Nullable] if any element is [Nullable], otherwise [Strict].
    A list with no [Nullable] and no [Strict] element (including the empty
    list) also yields [Strict], since a leftover [Depends] is resolved to
    [Strict]. *)
let common_nullability l =
  (* Nullable dominates Strict, which dominates Depends *)
  let join acc ty =
    match acc, ty.nullability with
    | Nullable, _ | _, Nullable -> Nullable
    | Strict, _ | _, Strict -> Strict
    | Depends, Depends -> Depends
  in
  match List.fold_left join Depends l with
  | Depends -> Strict
  | (Nullable | Strict) as resolved -> resolved
(** [undepend t nullability] replaces the nullability of [t] with the given one
    when it is still [Depends]; otherwise [t] is returned unchanged. *)
let undepend t nullability =
  match t.nullability with
  | Depends -> { t with nullability }
  | Nullable | Strict -> t

(** [common_type x y] is the type shared by [x] and [y], or [None] when their
    kinds are unrelated ([order_kind] answers [`No]).

    Kind: when the kinds are equal it is kept; when they are ordered, the first
    component of the [`Order] pair is chosen.

    Nullability: when [order_nullability] answers [`Equal n], the result
    carries [n] (so a [Depends] side adopts the other side's nullability);
    when one side is [Strict] and the other [Nullable], the result is
    [Nullable]. *)
let common_type x y =
  match order_nullability x.nullability y.nullability, order_kind x.t y.t with
  | _, `No -> None
  | `Equal nullability, `Order (t,_) -> Some { t; nullability }
  (* use the resolved nullability rather than returning [x] as is: otherwise
     [Depends] vs [Strict]/[Nullable] of the same kind would stay [Depends]
     and the result would depend on argument order *)
  | `Equal nullability, `Equal -> Some { x with nullability }
  | (`Nullable_Strict|`Strict_Nullable), `Equal -> Some (nullable x.t)
  | (`Nullable_Strict|`Strict_Nullable), `Order (sub,_) -> Some (nullable sub)

(** [has_common_type x y] is [true] exactly when [common_type x y] is [Some _]. *)
let has_common_type x y =
  let shared = common_type x y in
  Option.is_some shared

type tyvar = Typ of t | Var of int
let string_of_tyvar = function Typ t -> to_string t | Var i -> sprintf "'%c" (Char.chr @@ Char.code 'a' + i)
let string_of_tyvar = function Typ t -> show t | Var i -> sprintf "'%c" (Char.chr @@ Char.code 'a' + i)

type func =
| Group of t (* _ -> t *)
| Agg (* 'a -> 'a *)
| Multi of tyvar * tyvar (* 'a -> ... -> 'a -> 'b *)
| Ret of t (* _ -> t *) (* TODO eliminate *)
| Ret of kind (* _ -> t *) (* TODO eliminate *)
| F of tyvar * tyvar list

let monomorphic ret args = F (Typ ret, List.map (fun t -> Typ t) args)
let fixed = monomorphic
let fixed ret args = monomorphic (depends ret) (List.map depends args)

let identity = F (Var 0, [Var 0])

let pp_func pp =
let open Format in
function
| Agg -> fprintf pp "|'a| -> 'a"
| Group ret -> fprintf pp "|_| -> %s" (to_string ret)
| Ret ret -> fprintf pp "_ -> %s" (to_string ret)
| Group ret -> fprintf pp "|_| -> %s" (show ret)
| Ret ret -> fprintf pp "_ -> %s" (show_kind ret)
| F (ret, args) -> fprintf pp "%s -> %s" (String.concat " -> " @@ List.map string_of_tyvar args) (string_of_tyvar ret)
| Multi (ret, each_arg) -> fprintf pp "{ %s }+ -> %s" (string_of_tyvar each_arg) (string_of_tyvar ret)

Expand Down Expand Up @@ -155,7 +195,7 @@ struct

(** [sub l a] keeps, in their original order, the elements of [l] that do not
    occur in [a] (membership is tested with [List.mem]). *)
let sub l a =
  let absent_from_a x = not (List.mem x a) in
  List.filter absent_from_a l

let to_string v = v |> List.map (fun attr -> sprintf "%s %s" (Type.to_string attr.domain) attr.name) |>
let to_string v = v |> List.map (fun attr -> sprintf "%s %s" (Type.show attr.domain) attr.name) |>
String.concat ", " |> sprintf "[%s]"
(** [names t] renders the attribute names of [t] as ["[n1,n2,...]"]. *)
let names t =
  let column_names = List.map (fun attr -> attr.name) t in
  sprintf "[%s]" (String.concat "," column_names)

Expand All @@ -176,14 +216,13 @@ struct
common @ sub t1 common @ sub t2 common

let check_types t1 t2 =
List.iter2 (fun a1 a2 ->
match a1.domain, a2.domain with
| Type.Any, _
| _, Type.Any -> ()
| x, y when x = y -> ()
| _ -> raise (Error (t1, sprintf "Atributes do not match : %s of type %s and %s of type %s"
a1.name (Type.to_string a1.domain)
a2.name (Type.to_string a2.domain)))) t1 t2
List.iter2 begin fun a1 a2 ->
match Type.matches a1.domain a2.domain with
| true -> ()
| false -> raise (Error (t1, sprintf "Atributes do not match : %s of type %s and %s of type %s"
a1.name (Type.show a1.domain)
a2.name (Type.show a2.domain)))
end t1 t2

let check_types t1 t2 =
try check_types t1 t2 with
Expand Down Expand Up @@ -231,7 +270,7 @@ type table = table_name * schema [@@deriving show]
let print_table out (name,schema) =
IO.write_line out (show_table_name name);
schema |> List.iter begin fun {name;domain;extra} ->
IO.printf out "%10s %s %s\n" (Type.to_string domain) name (Constraints.show extra)
IO.printf out "%10s %s %s\n" (Type.show domain) name (Constraints.show extra)
end;
IO.write_line out ""

Expand Down Expand Up @@ -344,7 +383,7 @@ type stmt =
| Update of table_name * assignments * expr option * order * param list (* where, order, limit *)
| UpdateMulti of source list * assignments * expr option
| Select of select_full
| CreateRoutine of string * Type.t option * (string * Type.t * expr option) list
| CreateRoutine of string * Type.kind option * (string * Type.kind * expr option) list

(*
open Schema
Expand Down Expand Up @@ -384,7 +423,7 @@ let exclude narg name = add_ (Some narg) None name
(* Thin wrappers over [add_]; its definition is not visible in this hunk, so
   the argument meanings below are inferred from the call shapes — confirm. *)

(** [add_multi typ name] calls [add_] with no argument count ([None]) and the
    function type [typ]; presumably registers [name] for any arity. *)
let add_multi typ name = add_ None (Some typ) name
(** [add narg typ name] calls [add_] with the argument count [narg] and the
    function type [typ]. *)
let add narg typ name = add_ (Some narg) (Some typ) name

let sponge = Type.(Multi (Typ Any, Typ Any))
let sponge = let open Type in let any = depends Any in Multi (Typ any, Typ any)

let lookup name narg =
let name = String.lowercase name in
Expand All @@ -411,31 +450,36 @@ let () =
let open Type in
let open Function in
let (||>) x f = List.iter f x in
"count" |> add 0 (Group Int); (* count( * ) - asterisk is treated as no parameters in parser *)
"count" |> add 1 (Group Int);
"avg" |> add 1 (Group Float);
let int = strict Int in
let float = strict Float in
let text = strict Text in
let datetime = strict Datetime in
"count" |> add 0 (Group int); (* count( * ) - asterisk is treated as no parameters in parser *)
"count" |> add 1 (Group int);
"avg" |> add 1 (Group float);
["max";"min";"sum"] ||> add 1 Agg;
["max";"min"] ||> multi_polymorphic; (* sqlite3 *)
["lower";"upper"] ||> monomorphic Text [Text];
"length" |> monomorphic Int [Text];
["random"] ||> monomorphic Int [];
"floor" |> monomorphic Int [Float];
["nullif";"ifnull"] ||> add 2 (F (Var 0, [Var 0; Var 0]));
["lower";"upper"] ||> monomorphic text [text];
"length" |> monomorphic int [text];
["random"] ||> monomorphic int [];
"floor" |> monomorphic int [float];
"nullif" |> add 2 (F (Var 0 (* TODO nullable *), [Var 0; Var 0]));
"ifnull" |> add 2 (F (Var 0, [Var 1; Var 0]));
["least";"greatest";"coalesce"] ||> multi_polymorphic;
"strftime" |> exclude 1; (* requires at least 2 arguments *)
["concat";"strftime"] ||> multi ~ret:(Typ Text) (Typ Text);
["date";"time"] ||> monomorphic Text [Datetime];
"julianday" |> multi ~ret:(Typ Float) (Typ Text);
"from_unixtime" |> monomorphic Datetime [Int];
"from_unixtime" |> monomorphic Text [Int;Text];
["pow"; "power"] ||> monomorphic Float [Float;Int];
"unix_timestamp" |> monomorphic Int [];
"unix_timestamp" |> monomorphic Int [Datetime];
["timestampdiff";"timestampadd"] ||> monomorphic Int [Unit `Interval;Datetime;Datetime];
["concat";"strftime"] ||> multi ~ret:(Typ text) (Typ text);
["date";"time"] ||> monomorphic text [datetime];
"julianday" |> multi ~ret:(Typ float) (Typ text);
"from_unixtime" |> monomorphic datetime [int];
"from_unixtime" |> monomorphic text [int;text];
["pow"; "power"] ||> monomorphic float [float;int];
"unix_timestamp" |> monomorphic int [];
"unix_timestamp" |> monomorphic int [datetime];
["timestampdiff";"timestampadd"] ||> monomorphic int [strict @@ Unit `Interval;datetime;datetime];
"any_value" |> add 1 (F (Var 0,[Var 0])); (* 'a -> 'a but not aggregate *)
"substring" |> monomorphic Text [Text; Int];
"substring" |> monomorphic Text [Text; Int; Int];
"substring_index" |> monomorphic Text [Text; Text; Int];
"last_insert_id" |> monomorphic Int [];
"last_insert_id" |> monomorphic Int [Int];
"substring" |> monomorphic text [text; int];
"substring" |> monomorphic text [text; int; int];
"substring_index" |> monomorphic text [text; text; int];
"last_insert_id" |> monomorphic int [];
"last_insert_id" |> monomorphic int [int];
()
62 changes: 35 additions & 27 deletions lib/sql_parser.mly
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
let make_limit l =
let param = function
| _, `Const _ -> None
| x, `Param { label=None; pos } -> Some (new_param { label = Some (match x with `Limit -> "limit" | `Offset -> "offset"); pos } Int)
| _, `Param id -> Some (new_param id Int)
| x, `Param { label=None; pos } -> Some (new_param { label = Some (match x with `Limit -> "limit" | `Offset -> "offset"); pos } (strict Int))
| _, `Param id -> Some (new_param id (strict Int))
in
List.filter_map param l, List.mem (`Limit,`Const 1) l

Expand Down Expand Up @@ -297,7 +297,12 @@ alter_pos: AFTER col=IDENT { `After col }
| { `Default }
drop_behavior: CASCADE | RESTRICT { }

column_def: name=IDENT t=sql_type? extra=column_def_extra* { make_attribute name (Option.default Int t) (Constraints.of_list @@ List.filter_map identity extra) }
column_def: name=IDENT t=sql_type? extra=column_def_extra*
{
(* keep only the extras that produced a constraint (those yielding None are dropped) *)
let extra = Constraints.of_list @@ List.filter_map identity extra in
(* the kind defaults to Int when no sql_type is given; the column is Nullable
   only when the Null constraint is present, otherwise Strict *)
let t = { t = Option.default Int t; nullability = if Constraints.mem Null extra then Nullable else Strict } in
make_attribute name t extra
}

column_def1: c=column_def { `Attr c }
| pair(CONSTRAINT,IDENT?)? l=table_constraint_1 index_options { `Constraint l }
Expand Down Expand Up @@ -332,7 +337,7 @@ column_def_extra: PRIMARY KEY { Some PrimaryKey }
| AUTOINCREMENT { Some Autoincrement }
| on_conflict { None }
| CHECK LPAREN expr RPAREN { None }
| DEFAULT e=default_value { if e = Value Any then Some Null else None } (* FIXME check type with column *)
| DEFAULT e=default_value { match e with Value { t = Any; nullability = _ } -> Some Null | _ -> None } (* FIXME check type with column *)
| COLLATE IDENT { None }

default_value: e=single_literal_value | e=datetime_value { e } (* sub expr ? *)
Expand All @@ -359,7 +364,7 @@ expr:
| e1=expr NUM_DIV_OP e2=expr %prec PLUS { Fun ((Ret Float),[e1;e2]) }
| e1=expr DIV e2=expr %prec PLUS { Fun ((Ret Int),[e1;e2]) }
| e1=expr boolean_bin_op e2=expr %prec AND { Fun ((fixed Bool [Bool;Bool]),[e1;e2]) }
| e1=expr comparison_op anyall? e2=expr %prec EQUAL { poly Bool [e1;e2] }
| e1=expr comparison_op anyall? e2=expr %prec EQUAL { poly (depends Bool) [e1;e2] }
| e1=expr CONCAT_OP e2=expr { Fun ((fixed Text [Text;Text]),[e1;e2]) }
| e=like_expr esc=escape?
{
Expand All @@ -375,16 +380,16 @@ expr:
| VALUES LPAREN n=IDENT RPAREN { Inserted n }
| v=literal_value | v=datetime_value { v }
| v=interval_unit { v }
| e1=expr mnot(IN) l=sequence(expr) { poly Bool (e1::l) }
| e1=expr mnot(IN) LPAREN select=select_stmt RPAREN { poly Bool [e1; SelectExpr (select, `AsValue)] }
| e1=expr mnot(IN) l=sequence(expr) { poly (depends Bool) (e1::l) }
| e1=expr mnot(IN) LPAREN select=select_stmt RPAREN { poly (depends Bool) [e1; SelectExpr (select, `AsValue)] }
| e1=expr IN table=table_name { Tables.check table; e1 }
| e1=expr k=in_or_not_in p=PARAM
{
let e = poly Bool [ e1; Inparam (new_param p Any) ] in
let e = poly (depends Bool) [ e1; Inparam (new_param p (depends Any)) ] in
InChoice ({ label = p.label; pos = ($startofs, $endofs) }, k, e )
}
| LPAREN select=select_stmt RPAREN { SelectExpr (select, `AsValue) }
| p=PARAM { Param (new_param p Any) }
| p=PARAM { Param (new_param p (depends Any)) }
| p=PARAM parser_state_ident LCURLY l=choices c2=RCURLY { let { label; pos=(p1,_p2) } = p in Choices ({ label; pos = (p1,c2+1)},l) }
| SUBSTRING LPAREN s=expr FROM p=expr FOR n=expr RPAREN
| SUBSTRING LPAREN s=expr COMMA p=expr COMMA n=expr RPAREN { Fun (Function.lookup "substring" 3, [s;p;n]) }
Expand All @@ -396,22 +401,22 @@ expr:
| CONVERT LPAREN e=expr COMMA t=sql_type RPAREN
| CAST LPAREN e=expr AS t=sql_type RPAREN { Fun (Ret t, [e]) }
| f=IDENT LPAREN p=func_params RPAREN { Fun (Function.lookup f (List.length p), p) }
| e=expr IS NOT? NULL { Fun (Ret Bool, [e]) }
| e1=expr IS NOT? distinct_from? e2=expr { poly Bool [e1;e2] }
| e=expr mnot(BETWEEN) a=expr AND b=expr { poly Bool [e;a;b] }
| mnot(EXISTS) LPAREN select=select_stmt RPAREN { Fun (F (Typ Bool, [Typ Any]),[SelectExpr (select,`Exists)]) }
| e=expr IS NOT? NULL { poly (strict Bool) [e] }
| e1=expr IS NOT? distinct_from? e2=expr { poly (strict Bool) [e1;e2] }
| e=expr mnot(BETWEEN) a=expr AND b=expr { poly (depends Bool) [e;a;b] }
| mnot(EXISTS) LPAREN select=select_stmt RPAREN { Fun (F (Typ (strict Bool), [Typ (depends Any)]),[SelectExpr (select,`Exists)]) }
| CASE e1=expr? branches=nonempty_list(case_branch) e2=preceded(ELSE,expr)? END (* FIXME typing *)
{
let t_args =
match e1 with
| None -> (List.flatten @@ List.map (fun _ -> [Typ Bool; Var 1]) branches)
| None -> (List.flatten @@ List.map (fun _ -> [Typ (depends Bool); Var 1]) branches)
| Some _ -> [Var 0] @ (List.flatten @@ List.map (fun _ -> [Var 0; Var 1]) branches)
in
let t_args = t_args @ maybe (fun _ -> Var 1) e2 in
let v_args = option_to_list e1 @ List.flatten branches @ option_to_list e2 in
Fun (F (Var 1, t_args), v_args)
}
| IF LPAREN e1=expr COMMA e2=expr COMMA e3=expr RPAREN { Fun (F (Var 0, [Typ Bool;Var 0;Var 0]), [e1;e2;e3]) }
| IF LPAREN e1=expr COMMA e2=expr COMMA e3=expr RPAREN { Fun (F (Var 0, [Typ (depends Bool);Var 0;Var 0]), [e1;e2;e3]) }
| e=window_function OVER window_spec { e }

(* https://dev.mysql.com/doc/refman/8.0/en/window-functions-usage.html *)
Expand Down Expand Up @@ -442,24 +447,27 @@ choice_body: c1=LCURLY e=expr c2=RCURLY { (c1,Some e,c2) }
choice: parser_state_normal label=IDENT? e=choice_body? { let (c1,e,c2) = Option.default (0,None,0) e in ({ label; pos = (c1+1,c2) },e) }
choices: separated_nonempty_list(pair(parser_state_ident,NUM_BIT_OR),choice) { $1 }

datetime_value: | DATETIME_FUNC | DATETIME_FUNC LPAREN INTEGER? RPAREN { Value Datetime }
datetime_value: | DATETIME_FUNC | DATETIME_FUNC LPAREN INTEGER? RPAREN { Value (strict Datetime) }

literal_value:
| TEXT collate? { Value Text }
| BLOB collate? { Value Blob }
| INTEGER { Value Int }
| FLOAT { Value Float }
strict_value:
| TEXT collate? { Text }
| BLOB collate? { Blob }
| INTEGER { Int }
| FLOAT { Float }
| TRUE
| FALSE { Value Bool }
| FALSE { Bool }
| DATE TEXT
| TIME TEXT
| TIMESTAMP TEXT { Value Datetime }
| NULL { Value Any } (* he he *)
| TIMESTAMP TEXT { Datetime }

literal_value:
| strict_value { Value (strict $1) } (* any non-NULL literal is typed as Strict *)
| NULL { Value (nullable Any) } (* the NULL literal has no kind of its own: typed as a nullable Any *)

single_literal_value:
| literal_value { $1 }
| MINUS INTEGER { Value Int }
| MINUS FLOAT { Value Float }
| MINUS INTEGER { Value (strict Int) }
| MINUS FLOAT { Value (strict Float) }

expr_list: l=commas(expr) { l }
func_params: DISTINCT? l=expr_list { l }
Expand All @@ -478,7 +486,7 @@ interval_unit: MICROSECOND | SECOND | MINUTE | HOUR | DAY | WEEK | MONTH | QUART
| SECOND_MICROSECOND | MINUTE_MICROSECOND | MINUTE_SECOND
| HOUR_MICROSECOND | HOUR_SECOND | HOUR_MINUTE
| DAY_MICROSECOND | DAY_SECOND | DAY_MINUTE | DAY_HOUR
| YEAR_MONTH { Value (Unit `Interval) }
| YEAR_MONTH { Value (strict (Unit `Interval)) }

sql_type_flavor: T_INTEGER UNSIGNED? ZEROFILL? { Int }
| T_DECIMAL { Decimal }
Expand Down
Loading

0 comments on commit 5002eb8

Please sign in to comment.