Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Safer alternative to Lwt_io.establish_server #258

Merged
merged 3 commits into from
Jun 24, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 56 additions & 19 deletions src/unix/lwt_io.ml
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,11 @@ let close : type mode. mode channel -> unit Lwt.t = fun wrapper ->
(fun _ ->
abort wrapper)

let is_closed wrapper =
match wrapper.state with
| Closed -> true
| _ -> false

let flush_all () =
let wrappers = Outputs.fold (fun x l -> x :: l) outputs [] in
Lwt_list.iter_p
Expand Down Expand Up @@ -1346,26 +1351,26 @@ let with_file ?buffer ?flags ?perm ~mode filename f =

let file_length filename = with_file ~mode:input filename length

let close_socket fd =
Lwt.finalize
(fun () ->
Lwt.catch
(fun () ->
Lwt_unix.shutdown fd Unix.SHUTDOWN_ALL;
Lwt.return_unit)
(function
(* Occurs if the peer closes the connection first. *)
| Unix.Unix_error (Unix.ENOTCONN, _, _) -> Lwt.return_unit
| exn -> Lwt.fail exn))
(fun () ->
Lwt_unix.close fd)

let open_connection ?fd ?in_buffer ?out_buffer sockaddr =
let fd = match fd with
| None -> Lwt_unix.socket (Unix.domain_of_sockaddr sockaddr) Unix.SOCK_STREAM 0
| Some fd -> fd
in
let close = lazy begin
Lwt.finalize
(fun () ->
Lwt.catch
(fun () ->
Lwt_unix.shutdown fd Unix.SHUTDOWN_ALL;
Lwt.return_unit)
(function
| Unix.Unix_error(Unix.ENOTCONN, _, _) ->
(* This may happen if the server closed the connection before us *)
Lwt.return_unit
| exn -> Lwt.fail exn))
(fun () ->
Lwt_unix.close fd)
end in
let close = lazy (close_socket fd) in
Lwt.catch
(fun () ->
Lwt_unix.connect fd sockaddr >>= fun () ->
Expand Down Expand Up @@ -1406,10 +1411,7 @@ let establish_server ?fd ?(buffer_size = !default_buffer_size) ?(backlog=5) sock
Lwt.pick [Lwt_unix.accept sock >|= (fun x -> `Accept x); abort_waiter] >>= function
| `Accept(fd, addr) ->
(try Lwt_unix.set_close_on_exec fd with Invalid_argument _ -> ());
let close = lazy begin
Lwt_unix.shutdown fd Unix.SHUTDOWN_ALL;
Lwt_unix.close fd
end in
let close = lazy (close_socket fd) in
f (of_fd ~buffer:(Lwt_bytes.create buffer_size) ~mode:input
~close:(fun () -> Lazy.force close) fd,
of_fd ~buffer:(Lwt_bytes.create buffer_size) ~mode:output
Expand All @@ -1427,6 +1429,41 @@ let establish_server ?fd ?(buffer_size = !default_buffer_size) ?(backlog=5) sock
ignore (loop ());
{ shutdown = lazy(Lwt.wakeup abort_wakener `Shutdown) }

let establish_server_safe ?fd ?buffer_size ?backlog sockaddr f =
let best_effort_close channel =
(* First, check whether the channel is closed. f may have already tried to
close the channel, received an exception, and handled it somehow. If so,
trying to close the channel here will trigger the same exception, which
will go to !Lwt.async_exception_hook, despite the user's efforts. *)
(* The Invalid state is not possible on the channel, because it was not
created using Lwt_io.atomic. *)
if is_closed channel then
Lwt.return_unit
else
Lwt.catch
(fun () -> close channel)
(fun exn ->
!Lwt.async_exception_hook exn;
Lwt.return_unit)
in

let handler ((input_channel, output_channel) as channels) =
Lwt.async (fun () ->
(* Not using Lwt.finalize here, to make sure that exceptions from [f]
reach !Lwt.async_exception_hook before exceptions from closing the
channels. *)
Lwt.catch
(fun () -> f channels)
(fun exn ->
!Lwt.async_exception_hook exn;
Lwt.return_unit)

>>= fun () -> best_effort_close input_channel
>>= fun () -> best_effort_close output_channel)
in

establish_server ?fd ?buffer_size ?backlog sockaddr handler

let ignore_close ch =
ignore (close ch)

Expand Down
45 changes: 39 additions & 6 deletions src/unix/lwt_io.mli
Original file line number Diff line number Diff line change
Expand Up @@ -413,18 +413,51 @@ val with_connection :
type server
(** Type of a server *)

val establish_server_safe :
?fd : Lwt_unix.file_descr ->
?buffer_size : int ->
?backlog : int ->
Unix.sockaddr -> (input_channel * output_channel -> unit Lwt.t) -> server
(** [establish_server_safe ?fd ?buffer_size ?backlog sockaddr f] creates a
server which listens for incoming connections. New connections are passed
to [f]. When threads returned by [f] complete, the connections are closed
automatically.

The server does not wait for each thread. It begins accepting new
connections immediately.

If a thread raises an exception, it is passed to
[!Lwt.async_exception_hook]. Likewise, if the automatic [close] of a
connection raises an exception, it is passed to
[!Lwt.async_exception_hook]. To handle exceptions raised by [close], call
it manually inside [f]. *)

val establish_server :
?fd : Lwt_unix.file_descr ->
?buffer_size : int ->
?backlog : int ->
Unix.sockaddr -> (input_channel * output_channel -> unit) -> server
(** [establish_server ?fd ?buffer_size ?backlog sockaddr f] creates
a server which will listen for incoming connections. New
connections are passed to [f]. Note that [f] must not raise any
exception. If [fd] is not specified, a fresh file descriptor will
be created.
(** [establish_server ?fd ?buffer_size ?backlog sockaddr f] creates a server
which listens for incoming connections. New connections are passed to [f].

[establish_server] does not start separate threads for running [f], nor
close the connections passed to [f]. Thus, the skeleton of a practical
server based on [establish_server] might look like this:

{[
Lwt_io.establish_server address (fun (ic, oc) ->
Lwt.async (fun () ->

(* ... *)

Lwt.catch (fun () -> Lwt_io.close oc) (fun _ -> Lwt.return_unit) >>=
Lwt.catch (fun () -> Lwt_io.close ic) (fun _ -> Lwt.return_unit)))
]}

If [fd] is not specified, a fresh file descriptor will be created for
listening.

[backlog] is the argument passed to [Lwt_unix.listen] *)
[backlog] is the argument passed to [Lwt_unix.listen]. *)

val shutdown_server : server -> unit
(** Shutdown the given server *)
Expand Down
Loading