Skip to content

Commit

Permalink
Merge pull request #7091 from rabbitmq/mqtt-max-size-connect-packet
Browse files Browse the repository at this point in the history
Set MQTT max packet size
  • Loading branch information
michaelklishin authored Jan 29, 2023
2 parents d8da0b5 + 02cf072 commit 5dc2d9b
Show file tree
Hide file tree
Showing 6 changed files with 124 additions and 53 deletions.
4 changes: 4 additions & 0 deletions .git-blame-ignore-revs
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Revert "Format MQTT code with erlfmt"
209f23fa2f58e0240116b3e8e5be9cd54d34b569
# Format MQTT code with erlfmt
1de9fcf582def91d1cee6bea457dd24e8a53a431
125 changes: 79 additions & 46 deletions deps/rabbitmq_mqtt/src/rabbit_mqtt_packet.erl
Original file line number Diff line number Diff line change
Expand Up @@ -10,86 +10,66 @@
-include("rabbit_mqtt_packet.hrl").
-include("rabbit_mqtt.hrl").

-export([parse/2, initial_state/0, serialise/2]).
-export([init_state/0, reset_state/0,
parse/2, serialise/2]).
-export_type([state/0]).

-opaque state() :: none | fun().
-opaque state() :: unauthenticated | authenticated | fun().

-define(RESERVED, 0).
-define(MAX_LEN, 16#fffffff).
-define(HIGHBIT, 2#10000000).
-define(LOWBITS, 2#01111111).
-define(MAX_MULTIPLIER, ?HIGHBIT * ?HIGHBIT * ?HIGHBIT).
-define(MAX_PACKET_SIZE_CONNECT, 65_536).

-spec initial_state() -> state().
initial_state() -> none.
-spec init_state() -> state().
init_state() -> unauthenticated.

-spec reset_state() -> state().
reset_state() -> authenticated.

-spec parse(binary(), state()) ->
{more, state()} |
{ok, mqtt_packet(), binary()} |
{error, any()}.
parse(<<>>, none) ->
{more, fun(Bin) -> parse(Bin, none) end};
parse(<<MessageType:4, Dup:1, QoS:2, Retain:1, Rest/binary>>, none) ->
parse(<<>>, authenticated) ->
{more, fun(Bin) -> parse(Bin, authenticated) end};
parse(<<MessageType:4, Dup:1, QoS:2, Retain:1, Rest/binary>>, authenticated) ->
parse_remaining_len(Rest, #mqtt_packet_fixed{ type = MessageType,
dup = bool(Dup),
qos = QoS,
retain = bool(Retain) });
parse(Bin, Cont) -> Cont(Bin).
parse(<<?CONNECT:4, 0:4, Rest/binary>>, unauthenticated) ->
parse_remaining_len(Rest, #mqtt_packet_fixed{type = ?CONNECT});
parse(Bin, Cont)
when is_function(Cont) ->
Cont(Bin).

parse_remaining_len(<<>>, Fixed) ->
{more, fun(Bin) -> parse_remaining_len(Bin, Fixed) end};
parse_remaining_len(Rest, Fixed) ->
parse_remaining_len(Rest, Fixed, 1, 0).

parse_remaining_len(_Bin, _Fixed, Multiplier, _Length)
when Multiplier > ?MAX_MULTIPLIER ->
{error, malformed_remaining_length};
parse_remaining_len(_Bin, _Fixed, _Multiplier, Length)
when Length > ?MAX_LEN ->
{error, invalid_mqtt_packet_len};
{error, invalid_mqtt_packet_length};
parse_remaining_len(<<>>, Fixed, Multiplier, Length) ->
{more, fun(Bin) -> parse_remaining_len(Bin, Fixed, Multiplier, Length) end};
parse_remaining_len(<<1:1, Len:7, Rest/binary>>, Fixed, Multiplier, Value) ->
parse_remaining_len(Rest, Fixed, Multiplier * ?HIGHBIT, Value + Len * Multiplier);
parse_remaining_len(<<0:1, Len:7, Rest/binary>>, Fixed, Multiplier, Value) ->
parse_packet(Rest, Fixed, Value + Len * Multiplier).

parse_packet(Bin, #mqtt_packet_fixed{ type = Type,
qos = Qos } = Fixed, Length)
parse_packet(Bin, #mqtt_packet_fixed{type = ?CONNECT} = Fixed, Length) ->
parse_connect(Bin, Fixed, Length);
parse_packet(Bin, #mqtt_packet_fixed{type = Type,
qos = Qos} = Fixed, Length)
when Length =< ?MAX_LEN ->
case {Type, Bin} of
{?CONNECT, <<PacketBin:Length/binary, Rest/binary>>} ->
{ProtoName, Rest1} = parse_utf(PacketBin),
<<ProtoVersion : 8, Rest2/binary>> = Rest1,
<<UsernameFlag : 1,
PasswordFlag : 1,
WillRetain : 1,
WillQos : 2,
WillFlag : 1,
CleanSession : 1,
_Reserved : 1,
KeepAlive : 16/big,
Rest3/binary>> = Rest2,
{ClientId, Rest4} = parse_utf(Rest3),
{WillTopic, Rest5} = parse_utf(Rest4, WillFlag),
{WillMsg, Rest6} = parse_msg(Rest5, WillFlag),
{UserName, Rest7} = parse_utf(Rest6, UsernameFlag),
{PasssWord, <<>>} = parse_utf(Rest7, PasswordFlag),
case protocol_name_approved(ProtoVersion, ProtoName) of
true ->
wrap(Fixed,
#mqtt_packet_connect{
proto_ver = ProtoVersion,
will_retain = bool(WillRetain),
will_qos = WillQos,
will_flag = bool(WillFlag),
clean_sess = bool(CleanSession),
keep_alive = KeepAlive,
client_id = ClientId,
will_topic = WillTopic,
will_msg = WillMsg,
username = UserName,
password = PasssWord}, Rest);
false ->
{error, protocol_header_corrupt}
end;
{?PUBLISH, <<PacketBin:Length/binary, Rest/binary>>} ->
{TopicName, Rest1} = parse_utf(PacketBin),
{PacketId, Payload} = case Qos of
Expand Down Expand Up @@ -122,6 +102,59 @@ parse_packet(Bin, #mqtt_packet_fixed{ type = Type,
end}
end.

parse_connect(Bin, Fixed, Length) ->
MaxSize = application:get_env(?APP_NAME,
max_packet_size_unauthenticated,
?MAX_PACKET_SIZE_CONNECT),
case Length =< MaxSize of
true ->
case Bin of
<<PacketBin:Length/binary, Rest/binary>> ->
{ProtoName, Rest1} = parse_utf(PacketBin),
<<ProtoVersion : 8, Rest2/binary>> = Rest1,
<<UsernameFlag : 1,
PasswordFlag : 1,
WillRetain : 1,
WillQos : 2,
WillFlag : 1,
CleanSession : 1,
_Reserved : 1,
KeepAlive : 16/big,
Rest3/binary>> = Rest2,
{ClientId, Rest4} = parse_utf(Rest3),
{WillTopic, Rest5} = parse_utf(Rest4, WillFlag),
{WillMsg, Rest6} = parse_msg(Rest5, WillFlag),
{UserName, Rest7} = parse_utf(Rest6, UsernameFlag),
{PasssWord, <<>>} = parse_utf(Rest7, PasswordFlag),
case protocol_name_approved(ProtoVersion, ProtoName) of
true ->
wrap(Fixed,
#mqtt_packet_connect{
proto_ver = ProtoVersion,
will_retain = bool(WillRetain),
will_qos = WillQos,
will_flag = bool(WillFlag),
clean_sess = bool(CleanSession),
keep_alive = KeepAlive,
client_id = ClientId,
will_topic = WillTopic,
will_msg = WillMsg,
username = UserName,
password = PasssWord}, Rest);
false ->
{error, protocol_header_corrupt}
end;
TooShortBin
when byte_size(TooShortBin) < Length ->
{more, fun(BinMore) ->
parse_connect(<<TooShortBin/binary, BinMore/binary>>,
Fixed, Length)
end}
end;
false ->
{error, connect_packet_too_large}
end.

parse_topics(_, <<>>, Topics) ->
Topics;
parse_topics(?SUBSCRIBE = Sub, Bin, Topics) ->
Expand Down
4 changes: 2 additions & 2 deletions deps/rabbitmq_mqtt/src/rabbit_mqtt_reader.erl
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ init(Ref) ->
connection_state = running,
received_connect_packet = false,
conserve = false,
parse_state = rabbit_mqtt_packet:initial_state(),
parse_state = rabbit_mqtt_packet:init_state(),
proc_state = ProcessorState},
State1 = control_throttle(State0),
State = rabbit_event:init_stats_timer(State1, #state.stats_timer),
Expand Down Expand Up @@ -336,7 +336,7 @@ process_received_bytes(Bytes,
{ok, ProcState1} ->
process_received_bytes(
Rest,
State #state{parse_state = rabbit_mqtt_packet:initial_state(),
State #state{parse_state = rabbit_mqtt_packet:reset_state(),
proc_state = ProcState1});
%% PUBLISH and more
{error, unauthorized = Reason, ProcState1} ->
Expand Down
29 changes: 29 additions & 0 deletions deps/rabbitmq_mqtt/test/shared_SUITE.erl
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ subgroups() ->
,clean_session_kill_node
,rabbit_status_connection_count
,trace
,max_packet_size_unauthenticated
]}
]},
{cluster_size_3, [],
Expand Down Expand Up @@ -1337,6 +1338,34 @@ trace(Config) ->
delete_queue(Ch, TraceQ),
[ok = emqtt:disconnect(C) || C <- [Pub, Sub]].

max_packet_size_unauthenticated(Config) ->
App = rabbitmq_mqtt,
Par = ClientId = ?FUNCTION_NAME,
Opts = [{will_topic, <<"will/topic">>}],

{C1, Connect} = util:start_client(
ClientId, Config, 0,
[{will_payload, binary:copy(<<"a">>, 64_000)} | Opts]),
?assertMatch({ok, _}, Connect(C1)),
ok = emqtt:disconnect(C1),

MaxSize = 500,
ok = rpc(Config, application, set_env, [App, Par, MaxSize]),

{C2, Connect} = util:start_client(
ClientId, Config, 0,
[{will_payload, binary:copy(<<"b">>, MaxSize + 1)} | Opts]),
true = unlink(C2),
?assertMatch({error, _}, Connect(C2)),

{C3, Connect} = util:start_client(
ClientId, Config, 0,
[{will_payload, binary:copy(<<"c">>, round(MaxSize / 2))} | Opts]),
?assertMatch({ok, _}, Connect(C3)),
ok = emqtt:disconnect(C3),

ok = rpc(Config, application, unset_env, [App, Par]).

%% -------------------------------------------------------------------
%% Internal helpers
%% -------------------------------------------------------------------
Expand Down
9 changes: 7 additions & 2 deletions deps/rabbitmq_mqtt/test/util.erl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
connect/2,
connect/3,
connect/4,
start_client/4,
get_events/1,
assert_event_type/2,
assert_event_prop/2,
Expand Down Expand Up @@ -119,6 +120,11 @@ connect(ClientId, Config, AdditionalOpts) ->
connect(ClientId, Config, 0, AdditionalOpts).

connect(ClientId, Config, Node, AdditionalOpts) ->
{C, Connect} = start_client(ClientId, Config, Node, AdditionalOpts),
{ok, _Properties} = Connect(C),
C.

start_client(ClientId, Config, Node, AdditionalOpts) ->
{Port, WsOpts, Connect} =
case rabbit_ct_helpers:get_config(Config, websocket, false) of
false ->
Expand All @@ -136,5 +142,4 @@ connect(ClientId, Config, Node, AdditionalOpts) ->
{clientid, rabbit_data_coercion:to_binary(ClientId)}
] ++ WsOpts ++ AdditionalOpts,
{ok, C} = emqtt:start_link(Options),
{ok, _Properties} = Connect(C),
C.
{C, Connect}.
6 changes: 3 additions & 3 deletions deps/rabbitmq_web_mqtt/src/rabbit_web_mqtt_handler.erl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

-record(state, {
socket :: {rabbit_proxy_socket, any(), any()} | rabbit_net:socket(),
parse_state = rabbit_mqtt_packet:initial_state() :: rabbit_mqtt_packet:state(),
parse_state = rabbit_mqtt_packet:init_state() :: rabbit_mqtt_packet:state(),
proc_state :: undefined | rabbit_mqtt_processor:state(),
connection_state = running :: running | blocked,
conserve = false :: boolean(),
Expand Down Expand Up @@ -273,7 +273,7 @@ handle_data1(Data, State = #state{ parse_state = ParseState,
{ok, ProcState1} ->
handle_data1(
Rest,
State#state{parse_state = rabbit_mqtt_packet:initial_state(),
State#state{parse_state = rabbit_mqtt_packet:reset_state(),
proc_state = ProcState1});
{error, Reason, _} ->
stop_mqtt_protocol_error(State, Reason, ConnName);
Expand All @@ -296,7 +296,7 @@ parse(Data, ParseState) ->
end.

stop_mqtt_protocol_error(State, Reason, ConnName) ->
?LOG_INFO("MQTT protocol error ~tp for connection ~tp", [Reason, ConnName]),
?LOG_WARNING("Web MQTT protocol error ~tp for connection ~tp", [Reason, ConnName]),
stop(State, ?CLOSE_PROTOCOL_ERROR, Reason).

stop(State) ->
Expand Down

0 comments on commit 5dc2d9b

Please sign in to comment.