Commit 6015e9f

v0.18~preview.130.05+548
public-release committed Nov 21, 2024
1 parent e4d20de commit 6015e9f
Showing 5 changed files with 30 additions and 28 deletions.
1 change: 1 addition & 0 deletions examples/translation/dune
@@ -2,5 +2,6 @@
  (modes byte exe)
  (names seq2seq)
  (libraries stdio torch unix)
+ (flags :standard -w -32)
  (preprocess
   (pps ppx_jane)))
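
Note on the dune change above: -w -32 turns off OCaml warning 32 ("unused value declaration") when building the seq2seq example. A minimal sketch of the kind of binding that triggers this warning; the module and values below are invented for illustration only:

    (* Warning 32 fires for a value that is never used and is hidden by the
       signature; building with -w -32 silences it. *)
    module M : sig
      val used : int
    end = struct
      let used = 1
      let unused = 2 (* would report: Warning 32: unused value unused. *)
    end

    let () = Printf.printf "%d\n" M.used
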
19 changes: 6 additions & 13 deletions src/gen_bindings/gen.ml
@@ -132,7 +132,6 @@ module Func = struct
   type arg =
     { arg_name : string
     ; arg_type : arg_type
-    ; default_value : string option
     ; is_const : bool
     }

@@ -481,17 +480,13 @@ let read_yaml filename =
         Map.find arg "is_nullable"
         |> Option.value_map ~default:false ~f:extract_bool
       in
-      let default_value =
-        Map.find arg "default" |> Option.map ~f:extract_string
-      in
+      let has_default_value = Map.find arg "default" |> Option.is_some in
       match Func.arg_type_of_string arg_type ~is_nullable with
-      | Some Scalar when Option.is_some default_value && not is_nullable -> None
+      | Some Scalar when has_default_value && not is_nullable -> None
       | Some TensorOptions
-        when Option.is_some default_value && Set.mem no_tensor_options name ->
-        None
-      | Some arg_type -> Some { Func.arg_name; arg_type; default_value; is_const }
-      | None ->
-        if Option.is_some default_value then None else raise Not_a_simple_arg)
+        when has_default_value && Set.mem no_tensor_options name -> None
+      | Some arg_type -> Some { Func.arg_name; arg_type; is_const }
+      | None -> if has_default_value then None else raise Not_a_simple_arg)
     in
     Some { Func.name; operator_name; overload_name; args; returns; kind }
   with
@@ -690,9 +685,7 @@ let methods =
     ; kind = `method_
     }
   in
-  let ca arg_name arg_type =
-    { Func.arg_name; arg_type; default_value = None; is_const = true }
-  in
+  let ca arg_name arg_type = { Func.arg_name; arg_type; is_const = true } in
   [ c "grad" [ ca "self" Tensor ]
   ; c "set_requires_grad" [ ca "self" Tensor; ca "r" Bool ]
   ; c "toType" [ ca "self" Tensor; ca "scalar_type" ScalarType ]
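
Note on the gen.ml change above: the generator only ever branches on whether an argument has a default value, never on what that default is, so the string-valued default_value field is replaced by a single presence check and dropped from Func.arg. A minimal standalone sketch of that presence check, using a hypothetical argument map built with Base (the keys mimic the YAML fields read above):

    open Base

    (* Hypothetical argument entry: only the presence of "default" matters. *)
    let arg = Map.of_alist_exn (module String) [ "name", "alpha"; "default", "1" ]

    (* Same presence check as read_yaml now performs. *)
    let has_default_value = Map.find arg "default" |> Option.is_some

    let () = Stdlib.Printf.printf "has default: %b\n" has_default_value
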
11 changes: 9 additions & 2 deletions src/torch/optimizer.ml
@@ -143,9 +143,16 @@ module Linear_interpolation = struct
     then t.ys.(Array.length t.xs - 1)
     else (
       let index =
-        Array.binary_search t.xs `First_greater_than_or_equal_to x ~compare:Float.compare
+        match
+          Array.binary_search
+            t.xs
+            `First_greater_than_or_equal_to
+            x
+            ~compare:Float.compare
+        with
+        | Some index -> index
+        | None -> failwith "linear interpolation failed to check bounds"
       in
-      let index = Option.value_local_exn index in
       let prev_x, prev_y = t.xs.(index - 1), t.ys.(index - 1) in
       let next_x, next_y = t.xs.(index), t.ys.(index) in
       ((prev_y *. (next_x -. x)) +. (next_y *. (x -. prev_x))) /. (next_x -. prev_x))
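
Note on the optimizer.ml change above: Array.binary_search returns an option, and the rewritten code matches on it directly instead of unwrapping it afterwards, failing with an explicit message if the earlier bounds checks did not guarantee a match. For reference, a standalone sketch (with made-up knot values) of the interpolation arithmetic performed once the index of the first knot greater than or equal to x is known:

    (* xs must be sorted and index must point at the first element >= x,
       with index > 0, as ensured by the surrounding bounds checks. *)
    let interpolate xs ys x index =
      let prev_x, prev_y = xs.(index - 1), ys.(index - 1) in
      let next_x, next_y = xs.(index), ys.(index) in
      ((prev_y *. (next_x -. x)) +. (next_y *. (x -. prev_x))) /. (next_x -. prev_x)

    let () =
      (* Knots (0., 0.) and (1., 10.): x = 0.25 interpolates to 2.5. *)
      Printf.printf "%f\n" (interpolate [| 0.; 1. |] [| 0.; 10. |] 0.25 1)
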
19 changes: 9 additions & 10 deletions src/vision/efficientnet.ml
@@ -40,7 +40,6 @@ let block_args () =
 type params =
   { width : float
   ; depth : float
-  ; res : int
   ; dropout : float
   }

@@ -180,38 +179,38 @@ let efficientnet ?(num_classes = 1000) vs params =
     |> Tensor.adaptive_avg_pool2d ~output_size:[ 1; 1 ]
     |> Tensor.squeeze_dim ~dim:(-1)
     |> Tensor.squeeze_dim ~dim:(-1)
-    |> Tensor.dropout ~p:0.2 ~is_training
+    |> Tensor.dropout ~p:params.dropout ~is_training
     |> Layer.forward fc)
 ;;
 
 let b0 ?num_classes vs =
-  efficientnet ?num_classes vs { width = 1.0; depth = 1.0; res = 224; dropout = 0.2 }
+  efficientnet ?num_classes vs { width = 1.0; depth = 1.0; dropout = 0.2 }
 ;;
 
 let b1 ?num_classes vs =
-  efficientnet ?num_classes vs { width = 1.0; depth = 1.1; res = 240; dropout = 0.2 }
+  efficientnet ?num_classes vs { width = 1.0; depth = 1.1; dropout = 0.2 }
 ;;
 
 let b2 ?num_classes vs =
-  efficientnet ?num_classes vs { width = 1.1; depth = 1.2; res = 260; dropout = 0.3 }
+  efficientnet ?num_classes vs { width = 1.1; depth = 1.2; dropout = 0.3 }
 ;;
 
 let b3 ?num_classes vs =
-  efficientnet ?num_classes vs { width = 1.2; depth = 1.4; res = 300; dropout = 0.3 }
+  efficientnet ?num_classes vs { width = 1.2; depth = 1.4; dropout = 0.3 }
 ;;
 
 let b4 ?num_classes vs =
-  efficientnet ?num_classes vs { width = 1.4; depth = 1.8; res = 380; dropout = 0.4 }
+  efficientnet ?num_classes vs { width = 1.4; depth = 1.8; dropout = 0.4 }
 ;;
 
 let b5 ?num_classes vs =
-  efficientnet ?num_classes vs { width = 1.6; depth = 2.2; res = 456; dropout = 0.4 }
+  efficientnet ?num_classes vs { width = 1.6; depth = 2.2; dropout = 0.4 }
 ;;
 
 let b6 ?num_classes vs =
-  efficientnet ?num_classes vs { width = 1.8; depth = 2.6; res = 528; dropout = 0.5 }
+  efficientnet ?num_classes vs { width = 1.8; depth = 2.6; dropout = 0.5 }
 ;;
 
 let b7 ?num_classes vs =
-  efficientnet ?num_classes vs { width = 2.0; depth = 3.1; res = 600; dropout = 0.5 }
+  efficientnet ?num_classes vs { width = 2.0; depth = 3.1; dropout = 0.5 }
 ;;
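
Note on the efficientnet.ml changes above: the res field is removed from params and from every bN constructor (nothing else in this diff reads it), and the final dropout layer now takes its probability from params.dropout rather than a hard-coded 0.2. A minimal sketch of the slimmed-down record, filled with the b0 values shown above:

    (* Sketch of the params record after the change, using the b0 values. *)
    type params =
      { width : float
      ; depth : float
      ; dropout : float
      }

    let b0_params = { width = 1.0; depth = 1.0; dropout = 0.2 }
    let () = Printf.printf "b0 dropout: %f\n" b0_params.dropout
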
8 changes: 5 additions & 3 deletions src/wrapper/dune
@@ -6,15 +6,17 @@
   -std=c++17
   -fPIC
   (:include cxx_flags.sexp)))
- (foreign_stubs
-  (language c)
-  (names torch_stubs torch_stubs_generated))
  (name torch_core)
  (public_name torch.core)
  (c_library_flags :standard -lstdc++)
  (flags
   :standard
   (:include flags.sexp))
+ (foreign_stubs
+  (language c)
+  (names torch_stubs torch_stubs_generated)
+  (flags :standard -Wno-discarded-qualifiers -Wno-incompatible-pointer-types)
+  (extra_deps torch_api_generated.h torch_api_generated.cpp))
  (libraries ctypes.foreign ctypes torch_bindings)
  (preprocess
   (pps ppx_jane)))