Skip to content

Commit

Permalink
add Websocket_lwt.check_origin (vbmithr#85)
Browse files Browse the repository at this point in the history
  • Loading branch information
zoggy committed Mar 31, 2017
1 parent 2a5b5d8 commit 8dbf70b
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 9 deletions.
33 changes: 24 additions & 9 deletions lib/websocket_lwt.ml
Original file line number Diff line number Diff line change
Expand Up @@ -109,18 +109,33 @@ let set_tcp_nodelay flow =
| TCP { fd; _ } -> Lwt_unix.setsockopt fd Lwt_unix.TCP_NODELAY true
| _ -> ()

module SSet = Set.Make(String)

let check_origin ?(origin_mandatory=false) ~hosts =
let pred origin_host = SSet.exists
(fun h -> String.Ascii.lowercase h = origin_host)
hosts
in
fun request ->
let headers = request.Cohttp.Request.headers in
match Cohttp.Header.get headers "origin" with
None -> not origin_mandatory
| Some origin ->
let origin = Uri.of_string origin in
match Uri.host origin with
| None -> not origin_mandatory
| Some host -> (* host is already lowercased by Uri *)
pred host

let check_origin_with_host request =
let headers = request.Cohttp.Request.headers in
let host = Cohttp.Header.get headers "host" in
let origin = Cohttp.Header.get headers "origin" in
match host, origin with
| None, _ -> failwith "Missing host header" (* mandatory in http/1.1 *)
| _, None -> true
| Some host, Some origin ->
(* remove port *)
let hostname = Option.value_map ~default:host ~f:fst (String.cut ~sep:":" host) in
let origin = Uri.of_string origin in
Some hostname = Uri.host origin
match host with
| None -> failwith "Missing host header" (* mandatory in http/1.1 *)
| Some host ->
(* remove port *)
let hostname = Option.value_map ~default:host ~f:fst (String.cut ~sep:":" host) in
check_origin ~hosts:(SSet.singleton hostname) request

let with_connection ?(extra_headers = Cohttp.Header.init ())
?(random_string=Rng.std ?state:None) ~ctx client uri =
Expand Down
12 changes: 12 additions & 0 deletions lib/websocket_lwt.mli
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,18 @@ module Connected_client : sig
(** [source t] is the source address of [t]. *)
end

module SSet : Set.S with type elt = string

val check_origin :
?origin_mandatory: bool -> hosts:SSet.t ->
Cohttp.Request.t -> bool
(** [check_origin ~hosts req] with return [true] is the origin header
exists and match one of the provided hostnames.
If origin header is not present of does not container a hostname,
return [not origin_mandatory]. Default value of [origin_mandatory]
is false.
Hostnames in [hosts] are (ascii-)lowercased when compared.*)

val check_origin_with_host : Cohttp.Request.t -> bool
(** [check_origin_with_host] returns false if the origin header exists and its
host doesn't match the host header *)
Expand Down

0 comments on commit 8dbf70b

Please sign in to comment.