diff --git a/.flake8 b/.flake8 index 9401b5a7..78dc53a5 100644 --- a/.flake8 +++ b/.flake8 @@ -1,6 +1,9 @@ [flake8] max-line-length = 150 -# ignore = E125, E123, E251 + +# ignore this flake8 warnings since they are not pep8 compliant. +# see `black` documentation +ignore = E203, E266, E501, W503 exclude = # No need to traverse our git directory .git, diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md index 1ff5c7d0..ed3e6343 100644 --- a/.github/CONTRIBUTING.md +++ b/.github/CONTRIBUTING.md @@ -23,7 +23,7 @@ apt-get -y install pandoc && pip install -e .[dev,test,benchmarks] - add tests - format your code using [black](https://github.com/ambv/black): ```shell -black --line-length=100 --py36 . +black --line-length=100 --target-version py36 . ``` - run [flake8](https://pypi.python.org/pypi/flake8) on the code and fix any issues: ```shell diff --git a/CHANGELOG.md b/CHANGELOG.md index b46b326d..d608bd29 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,15 @@ most recent version is listed first. +## **version:** v0.6.6 +- make sure that `naz` reads exactly the first 4bytes of an smpp header: https://github.com/komuw/naz/pull/153 + - if `naz` is unable to read exactly those bytes, it unbinds and closes the connection + - this is so as to ensure that `naz` behaves correctly and does not enter into an inconsistent state. +- make sire that `naz` reads exacly the first 16bytes of the smpp header: https://github.com/komuw/naz/pull/155 + - this builds on the [earlier work](https://github.com/komuw/naz/pull/153) but now `naz` takes it a step further and will unbind & close connection if it is unable to read the entire SMPP header + - this is done to prevent inconsistency and also to try and be faithful to the smpp spec. + + ## **version:** v0.6.5 - Simplify Breach log handler: https://github.com/komuw/naz/pull/152 diff --git a/naz/__version__.py b/naz/__version__.py index 2ab0d33b..a5ecd021 100644 --- a/naz/__version__.py +++ b/naz/__version__.py @@ -2,7 +2,7 @@ "__title__": "naz", "__description__": "Naz is an async SMPP client.", "__url__": "https://github.com/komuw/naz", - "__version__": "v0.6.5", + "__version__": "v0.6.6", "__author__": "komuW", "__author_email__": "komuw05@gmail.com", "__license__": "MIT", diff --git a/naz/client.py b/naz/client.py index fe8f3d44..b05b40a2 100644 --- a/naz/client.py +++ b/naz/client.py @@ -358,6 +358,7 @@ def __init__( self.naz_message_protocol_version = "1" self.current_session_state = SmppSessionState.CLOSED + self._header_pdu_length = 16 self.drain_duration = drain_duration self.socket_timeout = socket_timeout @@ -848,7 +849,7 @@ async def tranceiver_bind(self, log_id: str = "") -> None: ) # header - command_length = 16 + len(body) # 16 is for headers + command_length = self._header_pdu_length + len(body) # 16 is for headers command_id = self.command_ids[smpp_command] # the status for success see section 5.1.3 command_status = SmppCommandStatus.ESME_ROK.value @@ -961,7 +962,7 @@ async def enquire_link(self, TESTING: bool = False) -> typing.Union[None, bytes] body = b"" # header - command_length = 16 + len(body) # 16 is for headers + command_length = self._header_pdu_length + len(body) # 16 is for headers command_id = self.command_ids[smpp_command] command_status = 0x00000000 # not used for `enquire_link` try: @@ -1047,7 +1048,7 @@ async def enquire_link_resp(self, sequence_number: int) -> None: body = b"" # header - command_length = 16 + len(body) # 16 is for headers + command_length = self._header_pdu_length + len(body) # 16 is for headers command_id = self.command_ids[smpp_command] command_status = SmppCommandStatus.ESME_ROK.value sequence_number = sequence_number @@ -1106,7 +1107,7 @@ async def unbind_resp(self, sequence_number: int) -> None: body = b"" # header - command_length = 16 + len(body) # 16 is for headers + command_length = self._header_pdu_length + len(body) # 16 is for headers command_id = self.command_ids[smpp_command] command_status = SmppCommandStatus.ESME_ROK.value sequence_number = sequence_number @@ -1153,7 +1154,7 @@ async def deliver_sm_resp(self, sequence_number: int) -> None: ) # header - command_length = 16 + len(body) # 16 is for headers + command_length = self._header_pdu_length + len(body) # 16 is for headers command_id = self.command_ids[smpp_command] command_status = SmppCommandStatus.ESME_ROK.value sequence_number = sequence_number @@ -1346,7 +1347,7 @@ async def _build_submit_sm_pdu( ) # header - command_length = 16 + len(body) # 16 is for headers + command_length = self._header_pdu_length + len(body) # 16 is for headers command_id = self.command_ids[smpp_command] # the status for success see section 5.1.3 command_status = 0x00000000 # not used for `submit_sm` @@ -1842,22 +1843,34 @@ async def receive_data(self, TESTING: bool = False) -> typing.Union[None, bytes] ) return None - command_length_header_data = b"" + header_data = b"" try: if typing.TYPE_CHECKING: # make mypy happy; https://github.com/python/mypy/issues/4805 assert isinstance(self.reader, asyncio.streams.StreamReader) - # todo: look at `pause_reading` and `resume_reading` methods # `client.reader` and `client.writer` should not have timeouts since they are non-blocking # https://github.com/komuw/naz/issues/116 - command_length_header_data = await self.reader.read(4) - - # make sure the header data be 4 bytes - while len(command_length_header_data) != 4: - more_bytes = await self.reader.read(4 - len(command_length_header_data)) - command_length_header_data = command_length_header_data + more_bytes - + header_data = await self.reader.readexactly(self._header_pdu_length) + except asyncio.IncompleteReadError as e: + # see: https://github.com/komuw/naz/issues/135 + self._log( + logging.ERROR, + { + "event": "naz.Client.receive_data", + "stage": "end", + "state": "unable to read exactly {0}bytes of smpp header.".format( + self._header_pdu_length + ), + "error": str(e), + }, + ) + header_data == b"" + # close connection. it will be automatically reconnected later + await self._unbind_and_disconnect() + if TESTING: + # offer escape hatch for tests to come out of endless loop + return header_data except ( ConnectionError, TimeoutError, @@ -1877,7 +1890,7 @@ async def receive_data(self, TESTING: bool = False) -> typing.Union[None, bytes] }, ) - if command_length_header_data == b"": + if header_data == b"": retry_count += 1 poll_read_interval = self._retry_after(retry_count) self._log( @@ -1899,8 +1912,9 @@ async def receive_data(self, TESTING: bool = False) -> typing.Union[None, bytes] # we didn't fail to read from SMSC retry_count = 0 - total_pdu_length = struct.unpack(">I", command_length_header_data)[0] - MSGLEN = total_pdu_length - 4 + # first 4bytes of header are the command_length + total_pdu_length = struct.unpack(">I", header_data[:4])[0] + MSGLEN = total_pdu_length - self._header_pdu_length chunks = [] bytes_recd = 0 while bytes_recd < MSGLEN: @@ -1953,7 +1967,7 @@ async def receive_data(self, TESTING: bool = False) -> typing.Union[None, bytes] chunks.append(chunk) bytes_recd = bytes_recd + len(chunk) - full_pdu_data = command_length_header_data + b"".join(chunks) + full_pdu_data = header_data + b"".join(chunks) self._log( logging.DEBUG, { @@ -1982,8 +1996,8 @@ async def _parse_response_pdu(self, pdu: bytes) -> None: {"event": "naz.Client._parse_response_pdu", "stage": "start", "pdu": log_pdu}, ) - header_data = pdu[:16] - body_data = pdu[16:] + header_data = pdu[: self._header_pdu_length] + body_data = pdu[self._header_pdu_length :] command_id_header_data = header_data[4:8] command_status_header_data = header_data[8:12] sequence_number_header_data = header_data[12:16] @@ -2324,7 +2338,7 @@ async def unbind(self) -> None: body = b"" # header - command_length = 16 + len(body) # 16 is for headers + command_length = self._header_pdu_length + len(body) # 16 is for headers command_id = self.command_ids[smpp_command] command_status = 0x00000000 # not used for `unbind` try: diff --git a/tests/test_client.py b/tests/test_client.py index 80b05a6b..de6b024d 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -24,29 +24,82 @@ async def mock_coro(*args, **kwargs): return mock_coro +class MockStreamWriter: + """ + This is a mock of python's StreamWriter; + https://docs.python.org/3.6/library/asyncio-stream.html#asyncio.StreamWriter + """ + + def __init__(self, _is_closing=False): + self.transport = self._create_transport(_is_closing=_is_closing) + + async def drain(self): + pass + + def close(self): + # when this is called, we set the transport to be in closed/closing state + self.transport = self._create_transport(_is_closing=True) + + def write(self, data): + pass + + def get_extra_info(self, name, default=None): + # when this is called, we set the transport to be in open state. + # this is because this method is called in `naz.Client.connect` + # so it is the only chance we have of 're-establishing' connection + self.transport = self._create_transport(_is_closing=False) + + def _create_transport(self, _is_closing): + class MockTransport: + def __init__(self, _is_closing): + self._is_closing = _is_closing + + def set_write_buffer_limits(self, n): + pass + + def is_closing(self): + return self._is_closing + + return MockTransport(_is_closing=_is_closing) + + class MockStreamReader: """ This is a mock of python's StreamReader; https://docs.python.org/3.6/library/asyncio-stream.html#asyncio.StreamReader - - We mock the reader having a succesful submit_sm_resp PDU. - For the first read we return the first 4bytes, - the second read, we return the remaining bytes. """ def __init__(self, pdu): self.pdu = pdu - async def read(self, n_index=-1): - if n_index == 0: - return b"" blocks = [] blocks.append(self.pdu) - data = b"".join(blocks) - if n_index == 4: - return data[:n_index] + self.data = b"".join(blocks) + + async def read(self, n=-1): + if n == 0: + return b"" + + if n == -1: + _to_read_data = self.data # read all data + _remaining_data = b"" else: - return data[4:] + _to_read_data = self.data[:n] + _remaining_data = self.data[n:] + + self.data = _remaining_data + return _to_read_data + + async def readexactly(self, n): + _to_read_data = self.data[:n] + _remaining_data = self.data[n:] + + if len(_to_read_data) != n: + # unable to read exactly n bytes + raise asyncio.IncompleteReadError(partial=_to_read_data, expected=n) + + self.data = _remaining_data + return _to_read_data class TestClient(TestCase): @@ -472,11 +525,9 @@ def test_receving_data(self): submit_sm_resp_pdu = ( b"\x00\x00\x00\x12\x80\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x030\x00" ) - - # TODO: also create a MockStreamWriter mock_naz_connect.mock.return_value = ( MockStreamReader(pdu=submit_sm_resp_pdu), - "MockStreamWriter", + MockStreamWriter(), ) reader, writer = self._run(self.cli.connect()) @@ -485,6 +536,27 @@ def test_receving_data(self): received_pdu = self._run(self.cli.receive_data(TESTING=True)) self.assertEqual(received_pdu, submit_sm_resp_pdu) + def test_partial_reads_disconnect(self): + """ + test that if we are unable to read the full 16byte smpp header, + then we should close the connection. + """ + with mock.patch("naz.Client.connect", new=AsyncMock()) as mock_naz_connect, mock.patch( + "naz.Client._unbind_and_disconnect", new=AsyncMock() + ) as mock_naz_unbind_and_disconnect: + submit_sm_resp_pdu = b"\x00\x00\x00" + mock_naz_connect.mock.return_value = ( + MockStreamReader(pdu=submit_sm_resp_pdu), + MockStreamWriter(), + ) + + reader, writer = self._run(self.cli.connect()) + self.cli.reader = reader + self.cli.writer = writer + received_pdu = self._run(self.cli.receive_data(TESTING=True)) + self.assertEqual(received_pdu, b"") + self.assertTrue(mock_naz_unbind_and_disconnect.mock.called) + def test_enquire_link_resp(self): with mock.patch("naz.q.SimpleOutboundQueue.enqueue", new=AsyncMock()) as mock_naz_enqueue: sequence_number = 7