diff --git a/src/ejabberd_s2s.erl b/src/ejabberd_s2s.erl index cb4e5e5ecfc..6b9f3a03fb4 100644 --- a/src/ejabberd_s2s.erl +++ b/src/ejabberd_s2s.erl @@ -427,7 +427,7 @@ start_connection(From, To, Opts) -> MaxS2SConnectionsNumber, MaxS2SConnectionsNumberPerNode, Opts); true -> - %% We choose a connexion from the pool of opened ones. + %% We choose a connection from the pool of opened ones. {ok, choose_connection(From, L)} end end. diff --git a/src/ejabberd_s2s_in.erl b/src/ejabberd_s2s_in.erl index 48a650a4ee8..a2f9064f564 100644 --- a/src/ejabberd_s2s_in.erl +++ b/src/ejabberd_s2s_in.erl @@ -41,7 +41,7 @@ -export([handle_unexpected_info/2, handle_unexpected_cast/2, reject_unauthenticated_packet/2, process_closed/2]). %% API --export([stop/1, close/1, close/2, send/2, update_state/2, establish/1, +-export([stop/1, close/1, close/2, send/2, call/3, reply/2, update_state/2, establish/1, host_up/1, host_down/1]). -include("ejabberd.hrl"). @@ -96,6 +96,13 @@ establish(State) -> update_state(Ref, Callback) -> xmpp_stream_in:cast(Ref, {update_state, Callback}). +-spec call(pid(), term(), non_neg_integer() | infinity) -> term(). +call(Ref, Msg, Timeout) -> + xmpp_stream_in:call(Ref, Msg, Timeout). + +reply(Ref, Reply) -> + xmpp_stream_in:reply(Ref, Reply). + -spec host_up(binary()) -> ok. host_up(Host) -> ejabberd_hooks:add(s2s_in_closed, Host, ?MODULE, @@ -169,15 +176,17 @@ handle_stream_start(_StreamStart, #{lserver := LServer} = State) -> send(State, xmpp:serr_host_unknown()); true -> ServerHost = ejabberd_router:host_of_route(LServer), - State#{server_host => ServerHost} + State1 = State#{server_host => ServerHost}, + ejabberd_hooks:run_fold(s2s_in_stream_started, ServerHost, State1, []) end. handle_stream_end(Reason, #{server_host := LServer} = State) -> State1 = State#{stop_reason => Reason}, ejabberd_hooks:run_fold(s2s_in_closed, LServer, State1, [Reason]). - + handle_stream_established(State) -> - set_idle_timeout(State#{established => true}). + UniqueId = p1_time_compat:unique_integer(), % for xep-0198 s2s + set_idle_timeout(State#{established => true, unique_id => UniqueId}). handle_auth_success(RServer, Mech, _AuthModule, #{sockmod := SockMod, @@ -286,7 +295,7 @@ handle_info(Info, #{server_host := LServer} = State) -> ejabberd_hooks:run_fold(s2s_in_handle_info, LServer, State, [Info]). terminate(Reason, #{auth_domains := AuthDomains, - sockmod := SockMod, socket := Socket} = State) -> + sockmod := SockMod, socket := Socket, server_host := LServer} = State) -> case maps:get(stop_reason, State, undefined) of {tls, _} = Err -> ?ERROR_MSG("(~s) Failed to secure inbound s2s connection: ~s", @@ -302,7 +311,8 @@ terminate(Reason, #{auth_domains := AuthDomains, end, ok, AuthDomains); _ -> ok - end. + end, + ejabberd_hooks:run_fold(s2s_in_terminate, LServer, State, [Reason]). code_change(_OldVsn, State, _Extra) -> {ok, State}. diff --git a/src/ejabberd_s2s_out.erl b/src/ejabberd_s2s_out.erl index fea5d81625c..2280a247d4f 100644 --- a/src/ejabberd_s2s_out.erl +++ b/src/ejabberd_s2s_out.erl @@ -32,7 +32,8 @@ handle_auth_success/2, handle_auth_failure/3, handle_packet/2, handle_stream_end/2, handle_stream_downgraded/2, handle_recv/3, handle_send/3, handle_cdata/2, - handle_stream_established/1, handle_timeout/1]). + handle_stream_established/1, handle_timeout/1, + handle_authenticated_features/2]). -export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]). %% Hooks @@ -210,6 +211,10 @@ dns_retries(#{server := LServer}) -> dns_timeout(#{server := LServer}) -> ejabberd_config:get_option({s2s_dns_timeout, LServer}, timer:seconds(10)). +handle_authenticated_features(StreamFeatures, #{server_host := ServerHost} = State) -> + ejabberd_hooks:run_fold(s2s_out_authenticated_features, + ServerHost, State, [StreamFeatures]). + handle_auth_success(Mech, #{sockmod := SockMod, socket := Socket, ip := IP, remote_server := RServer, @@ -242,8 +247,9 @@ handle_stream_end(Reason, #{server_host := ServerHost} = State) -> handle_stream_downgraded(StreamStart, #{server_host := ServerHost} = State) -> ejabberd_hooks:run_fold(s2s_out_downgraded, ServerHost, State, [StreamStart]). -handle_stream_established(State) -> - State1 = State#{on_route => send}, +handle_stream_established(#{server_host := ServerHost} = State) -> + State0 = ejabberd_hooks:run_fold(s2s_out_established, ServerHost, State, []), + State1 = State0#{on_route => send}, State2 = resend_queue(State1), set_idle_timeout(State2). @@ -306,7 +312,7 @@ handle_info({route, Pkt}, #{queue := Q, on_route := Action} = State) -> handle_info(Info, #{server_host := ServerHost} = State) -> ejabberd_hooks:run_fold(s2s_out_handle_info, ServerHost, State, [Info]). -terminate(Reason, #{server := LServer, +terminate(Reason, #{server := LServer, server_host := ServerHost, remote_server := RServer} = State) -> ejabberd_s2s:remove_connection({LServer, RServer}, self()), State1 = case Reason of @@ -314,7 +320,8 @@ terminate(Reason, #{server := LServer, _ -> State#{stop_reason => internal_failure} end, bounce_queue(State1), - bounce_message_queue(State1). + bounce_message_queue(State1), + ejabberd_hooks:run_fold(s2s_out_terminate, ServerHost, State1, [Reason]). code_change(_OldVsn, State, _Extra) -> {ok, State}. diff --git a/src/mod_stream_mgmt.erl b/src/mod_stream_mgmt.erl index 2f6b0fc716d..bc1e1473622 100644 --- a/src/mod_stream_mgmt.erl +++ b/src/mod_stream_mgmt.erl @@ -35,6 +35,9 @@ c2s_handle_recv/3]). %% adjust pending session timeout -export([get_resume_timeout/1, set_resume_timeout/2]). +%% API (used by mod_stream_mgmt_s2s) +-export ([mgmt_queue_drop/2, mgmt_queue_add/2, cancel_ack_timer/1, + update_num_stanzas_in/2]). -include("xmpp.hrl"). -include("logger.hrl"). diff --git a/src/mod_stream_mgmt_s2s.erl b/src/mod_stream_mgmt_s2s.erl new file mode 100644 index 00000000000..18edd0c87bb --- /dev/null +++ b/src/mod_stream_mgmt_s2s.erl @@ -0,0 +1,810 @@ +-module(mod_stream_mgmt_s2s). +-behaviour(gen_mod). +-author('amuhar3@gmail.com'). +-protocol({xep, 198, '1.5.2'}). + +%% gen_mod API +-export([start/2, stop/1, reload/3, depends/2, mod_opt_type/1]). +%% client part hooks +-export([s2s_out_stream_init/2, s2s_out_stream_features/2, + s2s_out_packet/2, s2s_out_handle_recv/3, s2s_out_handle_send/3, + s2s_out_handle_info/2, s2s_out_closed/2, + s2s_out_terminate/2, s2s_out_established/1]). +%% server part hooks +-export([s2s_in_stream_started/1, s2s_in_stream_features/2, + s2s_in_unauthenticated_packet/2, s2s_in_authenticated_packet/2, + s2s_in_handle_info/2, s2s_in_closed/2, s2s_in_terminate/2]). + +-include("xmpp.hrl"). +-include("logger.hrl"). +-include("p1_queue.hrl"). + +-define(is_sm_packet(Pkt), + is_record(Pkt, sm_enable) or + is_record(Pkt, sm_enabled) or + is_record(Pkt, sm_resume) or + is_record(Pkt, sm_resumed) or + is_record(Pkt, sm_a) or + is_record(Pkt, sm_r)). + +-record(s2s, {fromto = {<<"">>, <<"">>} :: {binary(), binary()} | '_', + pid = self() :: pid() | '_' | '$1'}). + +-type state() :: map(). + +%%%============================================================================= +%%% API +%%%============================================================================= +start(Host, _Opts) -> + ejabberd_hooks:add(s2s_out_init, Host, ?MODULE, s2s_out_stream_init, 50), + % ejabberd_hooks:add(s2s_out_authenticated_features, + % Host, ?MODULE, s2s_out_stream_features, 50), + ejabberd_hooks:add(s2s_out_packet, Host, ?MODULE, s2s_out_packet, 50), + ejabberd_hooks:add(s2s_out_handle_recv, + Host, ?MODULE, s2s_out_handle_recv, 50), + ejabberd_hooks:add(s2s_out_handle_send, + Host, ?MODULE, s2s_out_handle_send, 50), + ejabberd_hooks:add(s2s_out_handle_info, + Host, ?MODULE, s2s_out_handle_info, 50), + ejabberd_hooks:add(s2s_out_closed, + Host, ?MODULE, s2s_out_closed, 50), + ejabberd_hooks:add(s2s_out_terminate, + Host, ?MODULE, s2s_out_terminate, 50), + ejabberd_hooks:add(s2s_out_established, + Host, ?MODULE, s2s_out_established, 50), + %% server part + ets_cache:new(sm_s2s), + ejabberd_hooks:add(s2s_in_stream_started, + Host, ?MODULE, s2s_in_stream_started, 50), + ejabberd_hooks:add(s2s_in_post_auth_features, + Host, ?MODULE, s2s_in_stream_features, 50), + ejabberd_hooks:add(s2s_in_unauthenticated_packet, + Host, ?MODULE, s2s_in_unauthenticated_packet, 50), + ejabberd_hooks:add(s2s_in_authenticated_packet, + Host, ?MODULE, s2s_in_authenticated_packet, 50), + ejabberd_hooks:add(s2s_in_handle_info, + Host, ?MODULE, s2s_in_handle_info, 50), + ejabberd_hooks:add(s2s_in_closed, + Host, ?MODULE, s2s_in_closed, 50), + ejabberd_hooks:add(s2s_in_terminate, + Host, ?MODULE, s2s_in_terminate, 50). + +stop(Host) -> + ejabberd_hooks:delete(s2s_out_init, Host, ?MODULE, s2s_out_stream_init, 50), + % ejabberd_hooks:delete(s2s_out_authenticated_features, + % Host, ?MODULE, s2s_out_stream_features, 50), + ejabberd_hooks:delete(s2s_out_packet, Host, ?MODULE, s2s_out_packet, 50), + ejabberd_hooks:delete(s2s_out_handle_recv, + Host, ?MODULE, s2s_out_handle_recv, 50), + ejabberd_hooks:delete(s2s_out_handle_send, + Host, ?MODULE, s2s_out_handle_send, 50), + ejabberd_hooks:delete(s2s_out_handle_info, + Host, ?MODULE, s2s_out_handle_info, 50), + ejabberd_hooks:delete(s2s_out_closed, + Host, ?MODULE, s2s_out_closed, 50), + ejabberd_hooks:delete(s2s_out_terminate, + Host, ?MODULE, s2s_out_terminate, 50), + ejabberd_hooks:delete(s2s_out_established, + Host, ?MODULE, s2s_out_established, 50), + %% server part + ejabberd_hooks:delete(s2s_in_stream_started, + Host, ?MODULE, s2s_in_stream_started, 50), + ejabberd_hooks:delete(s2s_in_post_auth_features, + Host, ?MODULE, s2s_in_stream_features, 50), + ejabberd_hooks:delete(s2s_in_unauthenticated_packet, + Host, ?MODULE, s2s_in_unauthenticated_packet, 50), + ejabberd_hooks:delete(s2s_in_authenticated_packet, + Host, ?MODULE, s2s_in_authenticated_packet, 50), + ejabberd_hooks:delete(s2s_in_handle_info, + Host, ?MODULE, s2s_in_handle_info, 50), + ejabberd_hooks:delete(s2s_in_closed, + Host, ?MODULE, s2s_in_closed, 50), + ejabberd_hooks:delete(s2s_in_terminate, + Host, ?MODULE, s2s_in_terminate, 50). + +reload(_Host, _NewOpts, _OldOpts) -> + ?WARNING_MSG("module ~s is reloaded, but new configuration will take " + "effect for newly created s2s connections only", [?MODULE]). + +depends(_Host, _Opts) -> []. + +%% client part +s2s_out_stream_init({ok, #{server_host := ServerHost} = State}, Opts) -> + State1 = State#{mgmt_timeout => get_resume_timeout(ServerHost), + mgmt_queue_type => get_queue_type(ServerHost), + mgmt_max_queue => get_max_ack_queue(ServerHost), + mgmt_ack_timeout => get_ack_timeout(ServerHost), + mgmt_connection_timeout => get_connection_timeout(ServerHost), + mgmt_stanzas_out => 0, + mgmt_stanzas_req => 0}, + case proplists:get_value(resume, Opts) of + OldState when OldState /= undefined -> + {ok, State1#{mgmt_state => connecting, mgmt_prev_session => OldState}}; + _ -> + {ok, State1#{mgmt_state => inactive}} + end; +s2s_out_stream_init(Acc, _Opts) -> + Acc. + +s2s_out_stream_features(#{mgmt_timeout := Timeout, + mgmt_queue_type := QueueType} = State, + #stream_features{sub_els = SubEls}) -> + case check_stream_mgmt_support(SubEls) of + Xmlns when Xmlns == ?NS_STREAM_MGMT_2; Xmlns == ?NS_STREAM_MGMT_3 -> + case State of + #{mgmt_prev_session := OldState} -> + #{mgmt_previd := Id} = OldState, + State1 = State#{mgmt_state => pending, + mgmt_queue => p1_queue:new(QueueType), + mgmt_xmlns => Xmlns}, + send(State1, #sm_resume{h = 0, + xmlns = Xmlns, + previd = Id}); + _ -> + Res = if Timeout > 0 -> + #sm_enable{xmlns = Xmlns, + resume = true, + max = Timeout}; + true -> + #sm_enable{xmlns = Xmlns} + end, + State1 = State#{mgmt_state => wait_for_enabled, + mgmt_queue => p1_queue:new(QueueType), + mgmt_xmlns => Xmlns}, + send(State1, Res) + end; + _ -> + State + end; +s2s_out_stream_features(State, _) -> + State. + +s2s_out_established(#{mgmt_state := connecting, + mgmt_queue_type := QueueType, + mgmt_prev_session := OldState} = State) -> + #{mgmt_previd := Id} = OldState, + State1 = State#{mgmt_state => pending, + mgmt_queue => p1_queue:new(QueueType), + mgmt_xmlns => ?NS_STREAM_MGMT_3}, + send(State1, #sm_resume{h = 0, + xmlns = ?NS_STREAM_MGMT_3, + previd = Id}); +s2s_out_established(#{mgmt_timeout := Timeout, + mgmt_queue_type := QueueType} = State) -> + Xmlns = ?NS_STREAM_MGMT_3, + Res = + if Timeout > 0 -> + #sm_enable{xmlns = Xmlns, + resume = true, + max = Timeout}; + true -> + #sm_enable{xmlns = Xmlns} + end, + + State1 = State#{mgmt_state => wait_for_enabled, + mgmt_queue => p1_queue:new(QueueType), + mgmt_xmlns => Xmlns}, + send(State1, Res); +s2s_out_established(State) -> + State. + +s2s_out_packet(#{mgmt_state := pending} = State, #sm_resumed{} = Pkt) -> + {stop, handle_resumed(Pkt, State)}; +s2s_out_packet(#{mgmt_state := MgmtState} = State, Pkt) + when ?is_sm_packet(Pkt) -> + if MgmtState == active -> + {stop, perform_stream_mgmt(Pkt, State)}; + MgmtState == wait_for_enabled -> + {stop, negotiate_stream_mgmt(Pkt, State)}; + true -> + {stop, State} + end; +s2s_out_packet(State, _Pkt) -> + State. + +s2s_out_handle_send(#{mgmt_state := MgmtState, + lang := Lang} = State, Pkt, SendResult) + when MgmtState == active; + MgmtState == wait_for_enabled; + MgmtState == pending -> + case Pkt of + _ when ?is_stanza(Pkt) -> + Meta = xmpp:get_meta(Pkt), + case maps:get(mgmt_is_resent, Meta, false) of + false -> + case mod_stream_mgmt:mgmt_queue_add(State, Pkt) of + #{mgmt_max_queue := exceeded} = State1 -> + Err = xmpp:serr_policy_violation( + <<"Too many unacked stanzas">>, Lang), + send(State1, Err); + State1 when MgmtState == active, SendResult == ok -> + send_rack(State1); + State1 -> + State1 + end; + true -> + State + end; + _ -> + State + end; +s2s_out_handle_send(State, _, _) -> + State. + +s2s_out_handle_info(#{mgmt_ack_timer := TRef, remote_server := RServer, + mod := Mod} = State, {timeout, TRef, ack_timeout}) -> + ?DEBUG("Timed out waiting for stream management " + "acknowledgement of ~s", [RServer]), + Mod:stop(State); +s2s_out_handle_info(#{mgmt_state := connecting, + remote_server := RServer, mod := Mod} = State, + {timeout, _TRef, connection_timeout}) -> + ?DEBUG("Timed out waiting for connection " + "establishment for resumption previous session with ~s", [RServer]), + Mod:stop(State#{mgmt_state => timeout}); +s2s_out_handle_info(State, _) -> + State. + +s2s_out_handle_recv(#{mgmt_state := wait_for_enabled, + remote_server := RServer} = State, _El, #sm_failed{}) -> + ?DEBUG("Remote server ~s can't enable stream management", [RServer]), + State#{mgmt_state => inactive}; +s2s_out_handle_recv(#{mgmt_state := pending, + remote_server := RServer, + mgmt_timeout := Timeout, + mgmt_xmlns := Xmlns, + mgmt_queue := Queue, + mgmt_prev_session := OldState} = State, + _El, #sm_failed{h = H}) -> + ?DEBUG("Remote server ~s can't resume previous session", [RServer]), + #{mgmt_queue := OldQueue, + mgmt_stanzas_out := OldNumStanzasOut} = + case H of + undefined-> + OldState; + _ -> + check_h_attribute(OldState, H) + end, + {NewNumStanzasOut, NewQueue} = + p1_queue:foldl( + fun({_, Time, Pkt}, {AccNum, AccQueue}) -> + Num = AccNum + 1, + {Num, p1_queue:in({Num, Time, Pkt}, AccQueue)} + end, {OldNumStanzasOut, OldQueue}, Queue), + State1 = State#{mgmt_state => wait_for_enabled, + mgmt_queue => NewQueue, + mgmt_stanzas_out => NewNumStanzasOut}, + State2 = maps:remove(mgmt_prev_session, State1), + State3 = + if Timeout > 0 -> + send(State2, #sm_enable{xmlns = Xmlns, + resume = true, + max = Timeout}); + true -> + send(State2, #sm_enable{xmlns = Xmlns}) + end, + resend_unacked_stanzas(State3, OldNumStanzasOut); +s2s_out_handle_recv(State, _El, _Pkt) -> + State. + +s2s_out_closed(#{mgmt_state := connecting} = State, _) -> + {stop, transition_to_connecting(State#{stream_state => connecting})}; +s2s_out_closed(State, _) -> + State. + +s2s_out_terminate(#{mgmt_state := active, mgmt_queue := Queue} = State, _Reason) + when ?qlen(Queue) > 0 -> + transition_to_connecting(State); +s2s_out_terminate(#{mgmt_state := timeout, + mgmt_prev_session := OldState} = State, _Reason) -> + #{mgmt_queue := Queue} = OldState, + bounce_errors(State, Queue), + State; +s2s_out_terminate(#{mgmt_state := pending, + mgmt_prev_session := OldState} = State, _Reason) -> + #{mgmt_queue := Queue} = OldState, + route_unacked_stanzas(State, Queue), + State; +s2s_out_terminate(State, _Reason) -> + State. + +%% server part + +s2s_in_stream_started(#{server_host := Host} = State) -> + Timeout = get_resume_timeout(Host), + MaxTimeout = get_max_resume_timeout(Host, Timeout), + State#{mgmt_state => inactive, + mgmt_timeout => Timeout, + mgmt_max_timeout => MaxTimeout, + mgmt_stanzas_in => 0}; +s2s_in_stream_started(State) -> + State. + +s2s_in_stream_features(Acc, _Host) -> + [#feature_sm{xmlns = ?NS_STREAM_MGMT_2}, + #feature_sm{xmlns = ?NS_STREAM_MGMT_3}| Acc]. + +s2s_in_unauthenticated_packet(State, Pkt) when ?is_sm_packet(Pkt) -> + Err = #sm_failed{reason = 'unexpected-request', + xmlns = ?NS_STREAM_MGMT_3}, + {stop, send(State, Err)}; +s2s_in_unauthenticated_packet(State, _Pkt) -> + State. + +s2s_in_authenticated_packet(#{mgmt_state := inactive} = State, + #sm_resume{} = Pkt) -> + {stop, handle_resume(Pkt, State)}; +s2s_in_authenticated_packet(#{mgmt_state := MgmtState} = State, Pkt) + when ?is_sm_packet(Pkt) -> + if MgmtState == active; MgmtState == pending -> + {stop, perform_stream_mgmt(Pkt, State)}; + true -> + {stop, negotiate_stream_mgmt(Pkt, State)} + end; +s2s_in_authenticated_packet(State, Pkt) -> + mod_stream_mgmt:update_num_stanzas_in(State, Pkt). + +s2s_in_handle_info(#{mgmt_state := pending, remote_server := RServer, + mod := Mod} = State, {timeout, _, pending_timeout}) -> + ?DEBUG("Timed out waiting for resumption of stream for ~s", [RServer]), + Mod:stop(State); +s2s_in_handle_info(#{mgmt_state := pending, mod:= Mod, + remote_server := RServer, unique_id := UniqueId} = State, + {_, From, {resume_session, RServer, UniqueId}}) -> + Mod:reply(From, {resume, State}), + {stop, State#{mgmt_state => resumed}}; +s2s_in_handle_info(#{mod := Mod} = State, {_, From, {resume_session, _, _}}) -> + Mod:reply(From, error), + {stop, State}; +s2s_in_handle_info(State, _Msg) -> + State. + +s2s_in_closed(#{mgmt_state := active} = State, _Reason) -> + {stop, transition_to_pending(State)}; +s2s_in_closed(State, _Reason) -> + State. + +s2s_in_terminate(#{mgmt_state := resumed, + remote_server := RServer} = State, _Reason) -> + ?INFO_MSG("Closing former stream of resumed session for ~s", [RServer]), + State; +s2s_in_terminate(#{mgmt_state := pending, + server_host := Server, + mgmt_stanzas_in := H} = State, _Reason) -> + ResumeId = make_resume_id(State), + ets_cache:insert_new(sm_s2s, {Server, ResumeId}, H), + State; +s2s_in_terminate(State, _Reason) -> + State. + +%%%============================================================================= +%%% Internal functions +%%%============================================================================= + +-spec check_stream_mgmt_support(Els :: [xmlel()]) -> binary(). +check_stream_mgmt_support(Els) -> + check_stream_mgmt_support(Els, <<>>). + +-spec check_stream_mgmt_support(Els :: [xmlel()], + Res :: binary()) -> binary(). +check_stream_mgmt_support([El | Els], Res) -> + case El of + #xmlel{name = <<"sm">>, attrs = Attrs} -> + case fxml:get_attr(<<"xmlns">>, Attrs) of + {value, ?NS_STREAM_MGMT_3} -> + ?NS_STREAM_MGMT_3; + {value, ?NS_STREAM_MGMT_2} -> + check_stream_mgmt_support(Els, ?NS_STREAM_MGMT_2); + _ -> + check_stream_mgmt_support(Els, Res) + end; + _ -> + check_stream_mgmt_support(Els, Res) + end; +check_stream_mgmt_support([], Res) -> Res. + +-spec negotiate_stream_mgmt(xmpp_element(), state()) -> state(). +negotiate_stream_mgmt(Pkt, State) -> + {ServerPart, Xmlns} = + case State of + #{mgmt_xmlns := MgmtXmlns} -> + {false, MgmtXmlns}; + _ -> + {true, xmpp:get_ns(Pkt)} + end, + case Pkt of + #sm_enable{} when ServerPart -> + handle_enable(State#{mgmt_xmlns => Xmlns}, Pkt); + #sm_enabled{} when not ServerPart -> + handle_enabled(State, Pkt); + _ when is_record(Pkt, sm_a); + is_record(Pkt, sm_r); + (is_record(Pkt, sm_resume) andalso ServerPart); + (is_record(Pkt, sm_resumed) andalso not ServerPart) -> + Err = #sm_failed{reason = 'unexpected-request', xmlns = Xmlns}, + send(State, Err); + _ -> + Err = #sm_failed{reason = 'bad-request', xmlns = Xmlns}, + send(State, Err) + end. + +-spec perform_stream_mgmt(xmpp_element(), state()) -> state(). +perform_stream_mgmt(Pkt, #{mgmt_xmlns := Xmlns} = State) -> + case xmpp:get_ns(Pkt) of + Xmlns -> + case Pkt of + #sm_a{} -> + handle_a(State, Pkt); + #sm_r{} -> + handle_r(State); + _ -> + send(State, #sm_failed{reason = 'bad-request', xmlns = Xmlns}) + end; + _ -> + send(State, #sm_failed{reason = 'unsupported-version', xmlns = Xmlns}) + end. + +handle_enable(#{remote_server := RServer, + mgmt_timeout := DefaultTimeout, + mgmt_max_timeout := MaxTimeout, + mgmt_xmlns := Xmlns} = State, + #sm_enable{resume = Resume, max = Max}) -> + Timeout = + if Resume == false -> + 0; + Max /= undefined, Max > 0, Max =< MaxTimeout -> + Max; + true -> + DefaultTimeout + end, + State1 = State#{mgmt_state => active, + mgmt_timeout => Timeout}, + Res = + if Timeout > 0 -> + ?INFO_MSG("Stream management with " + "resumption enabled for ~s", [RServer]), + #sm_enabled{resume = true, id = make_resume_id(State), + max = Timeout, xmlns = Xmlns}; + true -> + ?INFO_MSG("Stream management enabled for ~s", [RServer]), + #sm_enabled{xmlns = Xmlns} + end, + send(State1, Res). + +-spec handle_enabled(state(), sm_enabled()) -> state(). +handle_enabled(#{remote_server := RServer, + mgmt_timeout := DefaultTimeout, + mgmt_queue := Queue} = State, + #sm_enabled{resume = Resume, max = Max, id = Id}) -> + Timeout = if Resume == false -> + 0; + Max /= undefined -> + Max; + true -> + DefaultTimeout + end, + State1 = if Timeout > 0 -> + ?INFO_MSG("Stream management with " + "resumption enabled for ~s", [RServer]), + State#{mgmt_state => active, + mgmt_previd => Id, + mgmt_timeout => Timeout}; + true -> + ?INFO_MSG("Stream management enabled for ~s", [RServer]), + State#{mgmt_state => active, mgmt_timeout => Timeout} + end, + + case not p1_queue:is_empty(Queue) of + true -> + send_rack(State1); + _ -> + State1 + end. + +-spec handle_r(state()) -> state(). +handle_r(#{mgmt_stanzas_in := H, + mgmt_xmlns := Xmlns} = State) -> + send(State, #sm_a{h = H, xmlns = Xmlns}). + +-spec handle_a(state(), sm_a()) -> state(). +handle_a(State, #sm_a{h = H}) -> + State1 = check_h_attribute(State, H), + resend_rack(State1). + +-spec make_resume_id(state()) -> binary(). +make_resume_id(#{remote_server := RServer, unique_id := UniqueId}) -> + misc:term_to_base64({RServer, UniqueId}). + +-spec transition_to_connecting(state()) -> state(). +transition_to_connecting(#{mgmt_state := active, + mgmt_queue := Queue, + mgmt_timeout := 0} = State) -> + route_unacked_stanzas(State, Queue), + State; +transition_to_connecting(#{mgmt_state := active, + server_host := Server, + remote_server := RServer, + mgmt_connection_timeout := Timeout} = State) -> + State1 = mod_stream_mgmt:cancel_ack_timer(State), + ?DEBUG("Try to connect to remote server ~s", [RServer]), + case resume(Server, RServer, [{resume, State1}]) of + {ok, Pid} -> + erlang:start_timer(timer:seconds(Timeout), Pid, connection_timeout), + State1; + _ -> + State1 + end; +transition_to_connecting(#{mgmt_state := connecting, mod := Mod} = State) -> + Mod:connect(self()), + State; +transition_to_connecting(State) -> + State. + +-spec transition_to_pending(state()) -> state(). +transition_to_pending(#{mgmt_state := active, mod := Mod, + mgmt_timeout := 0} = State) -> + Mod:stop(State); +transition_to_pending(#{mgmt_state := active, + remote_server := RServer, + mgmt_timeout := Timeout} = State) -> + ?INFO_MSG("Waiting for resumption of stream for ~s", [RServer]), + erlang:start_timer(timer:seconds(Timeout), self(), pending_timeout), + State#{mgmt_state => pending}; +transition_to_pending(State) -> + State. + +resume(From, To, Opts) -> + {ok, Pid} = ejabberd_s2s_out:start(From, To, Opts), + F = fun() -> + mnesia:write(#s2s{fromto = {From, To}, pid = Pid}), + Pid + end, + TRes = mnesia:transaction(F), + case TRes of + {atomic, Pid} -> + ejabberd_s2s_out:connect(Pid), + {ok, Pid}; + {aborted, _Reason} -> + ejabberd_s2s_out:stop(Pid), + error + end. + +get_old_session_state(#{server_host := Server}, ResumeId, []) -> + case ets_cache:lookup(sm_s2s, {Server, ResumeId}) of + {ok, H} -> + {error, <<"Previous session timed out">>, H}; + _ -> + {error, <<"Previous session PID not found">>} + end; +get_old_session_state(State, ResumeId, [{_, Pid, _,_}|Specs]) + when is_pid(Pid), Pid /= self() -> + case misc:base64_to_term(ResumeId) of + {term, {RServer, UniqueId}} -> + try gen_fsm:sync_send_all_state_event(Pid, + {resume_session, RServer, UniqueId}) of + {resume, OldState} -> + ejabberd_s2s_in:stop(Pid), + {ok, OldState}; + error -> + get_old_session_state(State, ResumeId, Specs) + catch + _:_ -> + get_old_session_state(State, ResumeId, Specs) + end; + _ -> + {error, <<"Invalid 'previd' value">>} + end; +get_old_session_state(State, ResumeId, [_H|L]) -> + get_old_session_state(State, ResumeId, L). + +handle_resume(#sm_resume{previd = ResumeId, xmlns = Xmlns}, + #{remote_server := RServer, lang := Lang} = State) -> + Res = case get_old_session_state(State, ResumeId, + supervisor:which_children(ejabberd_s2s_in_sup)) of + {ok, OldState} -> + {ok, OldState}; + {error, Err, InH} -> + {error, #sm_failed{reason = 'item-not-found', + text = xmpp:mk_text(Err, Lang), + h = InH, xmlns = Xmlns}, Err}; + {error, Err} -> + {error, #sm_failed{reason = 'item-not-found', + text = xmpp:mk_text(Err, Lang), + xmlns = Xmlns}, Err} + end, + case Res of + {ok, OldSessionState} -> + #{mgmt_stanzas_in := H, + mgmt_timeout := Timeout, + mgmt_xmlns := AttrXmlns, + unique_id := UniqueId} = OldSessionState, + State1 = State#{mgmt_state => active, + mgmt_stanzas_in => H, + mgmt_timeout => Timeout, + mgmt_xmlns => AttrXmlns, + unique_id => UniqueId}, + + State2 = send(State1, #sm_resumed{previd = ResumeId, + h = H, + xmlns = AttrXmlns}), + ?INFO_MSG("Resumed session for ~s", [RServer]), + State2; + {error, El, Msg} -> + ?INFO_MSG("Cannot resume session for ~s: ~s", [RServer, Msg]), + send(State, El) + end. + +-spec handle_resumed(sm_resumed(), state()) -> state(). +handle_resumed(#sm_resumed{h = H, previd = _Id}, + #{remote_server := RServer, + mgmt_queue := Queue, + mgmt_prev_session := OldState} = State) -> + ResumedState = copy_state(OldState, State), + #{mgmt_xmlns := Xmlns, + mgmt_queue := OldQueue, + mgmt_stanzas_out := OldNumStanzasOut} = ResumedState, + State1 = check_h_attribute(ResumedState, H), + {NewNumStanzasOut, NewQueue} = + p1_queue:foldl( + fun({_, Time, Pkt}, {AccNum, AccQueue}) -> + Num = AccNum + 1, + {Num, p1_queue:in({Num, Time, Pkt}, AccQueue)} + end, {OldNumStanzasOut, OldQueue}, Queue), + State2 = State1#{mgmt_state => active, + mgmt_queue => NewQueue, + mgmt_stanzas_out => NewNumStanzasOut}, + State3 = resend_unacked_stanzas(State2, OldNumStanzasOut), + ?DEBUG("Resumed session for ~s", [RServer]), + send(State3, #sm_r{xmlns = Xmlns}). + +resend_unacked_stanzas(#{remote_server := RServer, + mgmt_queue := Queue} = State, LastStanzaNum) + when ?qlen(Queue) > 0 -> + ?DEBUG("Resending ~B unacknowledged stanza(s) to ~s", + [p1_queue:len(Queue), RServer]), + p1_queue:foldl( + fun({Num, Time, Pkt}, AccState) when Num =< LastStanzaNum -> + NewPkt = add_resent_delay_info(AccState, Pkt, Time), + send(AccState, xmpp:put_meta(NewPkt, mgmt_is_resent, true)) + end, State, Queue); +resend_unacked_stanzas(State, _) -> + State. + +route_unacked_stanzas(#{remote_server := RServer} = State, Queue) + when ?qlen(Queue) > 0 -> + ?DEBUG("Re-rout ~B unacknowledged stanza(s) to ~s", + [p1_queue:len(Queue), RServer]), + p1_queue:foreach( + fun({_, Time, Pkt}) -> + NewPkt = add_resent_delay_info(State, Pkt, Time), + ejabberd_router:route(xmpp:put_meta(NewPkt, mgmt_is_resent, true)) + end, Queue); +route_unacked_stanzas(_State, _Queue) -> + ok. + +bounce_errors(_State, Queue) + when ?qlen(Queue) > 0 -> + p1_queue:foreach( + fun({_, _, Pkt}) -> + Error = xmpp:err_remote_server_timeout(), + ejabberd_router:route_error(Pkt, Error) + end, Queue); +bounce_errors(_State, _) -> + ok. + +-spec copy_state(state(), state()) -> state(). +copy_state(#{mgmt_xmlns := Xmlns, + mgmt_queue := Queue, + mgmt_stanzas_out := NumStanzasOut, + mgmt_previd := Id} = _OldState, + #{mgmt_queue_type := QueueType} = NewState) -> + Queue1 = case QueueType of + ram -> p1_queue:file_to_ram(Queue); + _ -> p1_queue:ram_to_file(Queue) + end, + NewState#{mgmt_xmlns => Xmlns, + mgmt_queue => Queue1, + mgmt_stanzas_out => NumStanzasOut, + mgmt_previd => Id}. + +-spec check_h_attribute(state(), non_neg_integer()) -> state(). +check_h_attribute(#{mgmt_stanzas_out := NumStanzasOut, + remote_server := RServer} = State, H) + when H > NumStanzasOut -> + ?DEBUG("~s acknowledged ~B stanzas," + "but only ~B were sent ", [RServer, H, NumStanzasOut]), + mod_stream_mgmt:mgmt_queue_drop(State#{mgmt_stanzas_out => H}, NumStanzasOut); +check_h_attribute(#{mgmt_stanzas_out := NumStanzasOut, + remote_server := RServer} = State, H) -> + ?DEBUG("~s acknowledged ~B of ~B " + "stanzas", [RServer, H, NumStanzasOut]), + mod_stream_mgmt:mgmt_queue_drop(State, H). + +-spec add_resent_delay_info(state(), stanza(), erlang:timestamp()) -> stanza(). +add_resent_delay_info(#{server_host := LServer}, El, Time) -> + xmpp_util:add_delay_info(El, jid:make(LServer), Time, <<"Resent">>); +add_resent_delay_info(_State, El, _Time) -> + El. + +-spec send(state(), xmpp_element()) -> state(). +send(#{mod := Mod} = State, Pkt) -> + Mod:send(State, Pkt). + +send_rack(#{mgmt_ack_timer := _} = State) -> + State; +send_rack(#{mgmt_xmlns := Xmlns, + mgmt_stanzas_out := NumStanzasOut, + mgmt_ack_timeout := infinity} = State) -> + State1 = State#{mgmt_stanzas_req => NumStanzasOut}, + send(State1, #sm_r{xmlns = Xmlns}); +send_rack(#{mgmt_xmlns := Xmlns, + mgmt_stanzas_out := NumStanzasOut, + mgmt_ack_timeout := AckTimeout} = State) -> + TRef = erlang:start_timer(AckTimeout, self(), ack_timeout), + State1 = State#{mgmt_ack_timer => TRef, mgmt_stanzas_req => NumStanzasOut}, + send(State1, #sm_r{xmlns = Xmlns}). + +resend_rack(#{mgmt_ack_timer := _, + mgmt_queue := Queue, + mgmt_stanzas_out := NumStanzasOut, + mgmt_stanzas_req := NumStanzasReq} = State) -> + State1 = mod_stream_mgmt:cancel_ack_timer(State), + case NumStanzasReq < NumStanzasOut andalso not p1_queue:is_empty(Queue) of + true -> send_rack(State1); + false -> State1 + end; +resend_rack(State) -> + State. + +%%%============================================================================= +%%% Configuration processing +%%%============================================================================= + +get_resume_timeout(Host) -> + gen_mod:get_module_opt(Host, ?MODULE, resume_timeout, 300). + +get_max_resume_timeout(Host, ResumeTimeout) -> + case gen_mod:get_module_opt(Host, ?MODULE, max_resume_timeout) of + undefined -> ResumeTimeout; + Max when Max >= ResumeTimeout -> Max; + _ -> ResumeTimeout + end. + +get_queue_type(Host) -> + case gen_mod:get_module_opt(Host, ?MODULE, queue_type) of + undefined -> ejabberd_config:default_queue_type(Host); + Type -> Type + end. + +get_max_ack_queue(Host) -> + gen_mod:get_module_opt(Host, ?MODULE, max_ack_queue, 1000). + +get_ack_timeout(Host) -> + case gen_mod:get_module_opt(Host, ?MODULE, ack_timeout, 60) of + infinity -> infinity; + T -> timer:seconds(T) + end. + +get_connection_timeout(Host) -> + gen_mod:get_module_opt(Host, ?MODULE, connection_timeout, 300). + +mod_opt_type(connection_timeout) -> + fun(I) when is_integer(I), I >= 0 -> I end; +mod_opt_type(max_ack_queue) -> + fun(I) when is_integer(I), I > 0 -> I; + (infinity) -> infinity + end; +mod_opt_type(ack_timeout) -> + fun(I) when is_integer(I), I > 0 -> I; + (infinity) -> infinity + end; +mod_opt_type(resume_timeout) -> + fun(I) when is_integer(I), I >= 0 -> I end; +mod_opt_type(max_resume_timeout) -> + fun(I) when is_integer(I), I >= 0 -> I end; +mod_opt_type(queue_type) -> + fun(file) -> file; + (ram) -> ram + end; +mod_opt_type(_) -> [max_ack_queue, ack_timeout, resume_timeout, max_resume_timeout, + queue_type, connection_timeout]. diff --git a/test/ejabberd_SUITE.erl b/test/ejabberd_SUITE.erl index 97c56159ac6..55a7b000f33 100644 --- a/test/ejabberd_SUITE.erl +++ b/test/ejabberd_SUITE.erl @@ -320,10 +320,10 @@ init_per_testcase(TestCase, OrigConfig) -> Password = ?config(password, Config), ejabberd_auth:try_register(User, Server, Password), open_session(bind(auth(connect(Config)))); - _ when TestGroup == s2s_tests -> + _ when TestGroup == s2s_tests; TestGroup == sms2s_single -> auth(connect(starttls(connect(Config)))); - _ -> - open_session(bind(auth(connect(Config)))) + _ -> + open_session(bind(auth(connect(Config)))) end. end_per_testcase(_TestCase, _Config) -> @@ -520,7 +520,8 @@ s2s_tests() -> test_missing_to, test_invalid_from, bad_nonza, - codec_failure]}]. + codec_failure]}, + sms2s_tests:single_cases()]. groups() -> [{ldap, [sequence], ldap_tests()}, diff --git a/test/ejabberd_SUITE_data/ejabberd.yml b/test/ejabberd_SUITE_data/ejabberd.yml index a648cb4223c..0577b70ca27 100644 --- a/test/ejabberd_SUITE_data/ejabberd.yml +++ b/test/ejabberd_SUITE_data/ejabberd.yml @@ -488,6 +488,8 @@ Welcome to this XMPP server." mod_stream_mgmt: max_ack_queue: 10 resume_timeout: 3 + mod_stream_mgmt_s2s: + resume_timeout: 3 mod_time: [] mod_version: [] registration_timeout: infinity diff --git a/test/sms2s_tests.erl b/test/sms2s_tests.erl new file mode 100644 index 00000000000..8599ac24167 --- /dev/null +++ b/test/sms2s_tests.erl @@ -0,0 +1,85 @@ +%%%------------------------------------------------------------------- +%%% Author : Anna Mukharram +%%%------------------------------------------------------------------- + +-module(sms2s_tests). + +%% API +-compile(export_all). + +-import(suite, [connect/1, send/2, recv/1, set_opt/3, + close_socket/1, disconnect/1]). + +-include("suite.hrl"). + +%%%=================================================================== +%%% API +%%%=================================================================== + +single_cases() -> + {sms2s_single, [sequence], + [single_test(enable), + single_test(resume), + single_test(resume_failed)]}. + +enable(Config) -> + Server = ?config(server, Config), + ServerJID = jid:make(<<"">>, Server, <<"">>), + From = ?config(stream_from, Config), + FromJID = jid:make(<<"">>, From, <<"">>), + Msg = #message{from = FromJID, to = ServerJID, type = headline, + body = [#text{data = <<"body">>}]}, + ct:comment("Stream management with resumption is enabled"), + send(Config, #sm_enable{resume = true, xmlns = ?NS_STREAM_MGMT_3}), + #sm_enabled{id = ID, resume = true} = recv(Config), + ct:comment("Initial request; 'h' should be 0"), + send(Config, #sm_r{xmlns = ?NS_STREAM_MGMT_3}), + #sm_a{h = 0} = recv(Config), + ct:comment("Sending three messages and requesting again; 'h' should be 3"), + send(Config, Msg), + send(Config, Msg), + send(Config, Msg), + send(Config, #sm_r{xmlns = ?NS_STREAM_MGMT_3}), + #sm_a{h = 3} = recv(Config), + ct:comment("Closing socket"), + close_socket(Config), + {save_config, set_opt(sm_previd, ID, Config)}. + +resume(Config) -> + {_, SMConfig} = ?config(saved_config, Config), + ID = ?config(sm_previd, SMConfig), + ct:comment("Resuming the session"), + send(Config, #sm_resume{previd = ID, h = 0, xmlns = ?NS_STREAM_MGMT_3}), + #sm_resumed{previd = ID, h = 3} = recv(Config), + ct:comment("Checking if the server counts stanzas correctly"), + Server = ?config(server, Config), + ServerJID = jid:make(<<"">>, Server, <<"">>), + From = ?config(stream_from, Config), + FromJID = jid:make(<<"">>, From, <<"">>), + Msg = #message{from = FromJID, to = ServerJID, type = headline, + body = [#text{data = <<"body">>}]}, + send(Config, Msg), + send(Config, #sm_r{xmlns = ?NS_STREAM_MGMT_3}), + #sm_a{h = 4} = recv(Config), + ct:comment("Closing socket"), + close_socket(Config), + {save_config, set_opt(sm_previd, ID, Config)}. + +resume_failed(Config) -> + {_, SMConfig} = ?config(saved_config, Config), + ID = ?config(sm_previd, SMConfig), + ct:comment("Waiting for the session to time out"), + ct:sleep(30000), + ct:comment("Trying to resume timed out session"), + send(Config, #sm_resume{previd = ID, h = 0, xmlns = ?NS_STREAM_MGMT_3}), + #sm_failed{reason = 'item-not-found', h = 4} = recv(Config), + disconnect(Config). + + +%%%=================================================================== +%%% Internal functions +%%%=================================================================== +single_test(T) -> + list_to_atom("sms2s_" ++ atom_to_list(T)). + +