diff --git a/examples/translation/dune b/examples/translation/dune index dbefff8..f76f595 100644 --- a/examples/translation/dune +++ b/examples/translation/dune @@ -2,5 +2,6 @@ (modes byte exe) (names seq2seq) (libraries stdio torch unix) + (flags :standard -w -32) (preprocess (pps ppx_jane))) diff --git a/src/gen_bindings/gen.ml b/src/gen_bindings/gen.ml index 75e1fed..b0540ee 100644 --- a/src/gen_bindings/gen.ml +++ b/src/gen_bindings/gen.ml @@ -132,7 +132,6 @@ module Func = struct type arg = { arg_name : string ; arg_type : arg_type - ; default_value : string option ; is_const : bool } @@ -481,17 +480,13 @@ let read_yaml filename = Map.find arg "is_nullable" |> Option.value_map ~default:false ~f:extract_bool in - let default_value = - Map.find arg "default" |> Option.map ~f:extract_string - in + let has_default_value = Map.find arg "default" |> Option.is_some in match Func.arg_type_of_string arg_type ~is_nullable with - | Some Scalar when Option.is_some default_value && not is_nullable -> None + | Some Scalar when has_default_value && not is_nullable -> None | Some TensorOptions - when Option.is_some default_value && Set.mem no_tensor_options name -> - None - | Some arg_type -> Some { Func.arg_name; arg_type; default_value; is_const } - | None -> - if Option.is_some default_value then None else raise Not_a_simple_arg) + when has_default_value && Set.mem no_tensor_options name -> None + | Some arg_type -> Some { Func.arg_name; arg_type; is_const } + | None -> if has_default_value then None else raise Not_a_simple_arg) in Some { Func.name; operator_name; overload_name; args; returns; kind } with @@ -690,9 +685,7 @@ let methods = ; kind = `method_ } in - let ca arg_name arg_type = - { Func.arg_name; arg_type; default_value = None; is_const = true } - in + let ca arg_name arg_type = { Func.arg_name; arg_type; is_const = true } in [ c "grad" [ ca "self" Tensor ] ; c "set_requires_grad" [ ca "self" Tensor; ca "r" Bool ] ; c "toType" [ ca "self" Tensor; ca "scalar_type" ScalarType ] diff --git a/src/torch/optimizer.ml b/src/torch/optimizer.ml index 9e178d0..88a0e77 100644 --- a/src/torch/optimizer.ml +++ b/src/torch/optimizer.ml @@ -143,9 +143,16 @@ module Linear_interpolation = struct then t.ys.(Array.length t.xs - 1) else ( let index = - Array.binary_search t.xs `First_greater_than_or_equal_to x ~compare:Float.compare + match + Array.binary_search + t.xs + `First_greater_than_or_equal_to + x + ~compare:Float.compare + with + | Some index -> index + | None -> failwith "linear interpolation failed to check bounds" in - let index = Option.value_local_exn index in let prev_x, prev_y = t.xs.(index - 1), t.ys.(index - 1) in let next_x, next_y = t.xs.(index), t.ys.(index) in ((prev_y *. (next_x -. x)) +. (next_y *. (x -. prev_x))) /. (next_x -. prev_x)) diff --git a/src/vision/efficientnet.ml b/src/vision/efficientnet.ml index 7714dd1..bdc8265 100644 --- a/src/vision/efficientnet.ml +++ b/src/vision/efficientnet.ml @@ -40,7 +40,6 @@ let block_args () = type params = { width : float ; depth : float - ; res : int ; dropout : float } @@ -180,38 +179,38 @@ let efficientnet ?(num_classes = 1000) vs params = |> Tensor.adaptive_avg_pool2d ~output_size:[ 1; 1 ] |> Tensor.squeeze_dim ~dim:(-1) |> Tensor.squeeze_dim ~dim:(-1) - |> Tensor.dropout ~p:0.2 ~is_training + |> Tensor.dropout ~p:params.dropout ~is_training |> Layer.forward fc) ;; let b0 ?num_classes vs = - efficientnet ?num_classes vs { width = 1.0; depth = 1.0; res = 224; dropout = 0.2 } + efficientnet ?num_classes vs { width = 1.0; depth = 1.0; dropout = 0.2 } ;; let b1 ?num_classes vs = - efficientnet ?num_classes vs { width = 1.0; depth = 1.1; res = 240; dropout = 0.2 } + efficientnet ?num_classes vs { width = 1.0; depth = 1.1; dropout = 0.2 } ;; let b2 ?num_classes vs = - efficientnet ?num_classes vs { width = 1.1; depth = 1.2; res = 260; dropout = 0.3 } + efficientnet ?num_classes vs { width = 1.1; depth = 1.2; dropout = 0.3 } ;; let b3 ?num_classes vs = - efficientnet ?num_classes vs { width = 1.2; depth = 1.4; res = 300; dropout = 0.3 } + efficientnet ?num_classes vs { width = 1.2; depth = 1.4; dropout = 0.3 } ;; let b4 ?num_classes vs = - efficientnet ?num_classes vs { width = 1.4; depth = 1.8; res = 380; dropout = 0.4 } + efficientnet ?num_classes vs { width = 1.4; depth = 1.8; dropout = 0.4 } ;; let b5 ?num_classes vs = - efficientnet ?num_classes vs { width = 1.6; depth = 2.2; res = 456; dropout = 0.4 } + efficientnet ?num_classes vs { width = 1.6; depth = 2.2; dropout = 0.4 } ;; let b6 ?num_classes vs = - efficientnet ?num_classes vs { width = 1.8; depth = 2.6; res = 528; dropout = 0.5 } + efficientnet ?num_classes vs { width = 1.8; depth = 2.6; dropout = 0.5 } ;; let b7 ?num_classes vs = - efficientnet ?num_classes vs { width = 2.0; depth = 3.1; res = 600; dropout = 0.5 } + efficientnet ?num_classes vs { width = 2.0; depth = 3.1; dropout = 0.5 } ;; diff --git a/src/wrapper/dune b/src/wrapper/dune index 796dea6..6d21edb 100644 --- a/src/wrapper/dune +++ b/src/wrapper/dune @@ -6,15 +6,17 @@ -std=c++17 -fPIC (:include cxx_flags.sexp))) - (foreign_stubs - (language c) - (names torch_stubs torch_stubs_generated)) (name torch_core) (public_name torch.core) (c_library_flags :standard -lstdc++) (flags :standard (:include flags.sexp)) + (foreign_stubs + (language c) + (names torch_stubs torch_stubs_generated) + (flags :standard -Wno-discarded-qualifiers -Wno-incompatible-pointer-types) + (extra_deps torch_api_generated.h torch_api_generated.cpp)) (libraries ctypes.foreign ctypes torch_bindings) (preprocess (pps ppx_jane)))