Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Read exactly the entire 16bytes smpp header #155

Merged
merged 11 commits into from
Aug 31, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .flake8
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
[flake8]
max-line-length = 150
# ignore = E125, E123, E251

# ignore this flake8 warnings since they are not pep8 compliant.
# see `black` documentation
ignore = E203, E266, E501, W503
exclude =
# No need to traverse our git directory
.git,
Expand Down
2 changes: 1 addition & 1 deletion .github/CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ apt-get -y install pandoc && pip install -e .[dev,test,benchmarks]
- add tests
- format your code using [black](https://github.com/ambv/black):
```shell
black --line-length=100 --py36 .
black --line-length=100 --target-version py36 .
```
- run [flake8](https://pypi.python.org/pypi/flake8) on the code and fix any issues:
```shell
Expand Down
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,15 @@
most recent version is listed first.


## **version:** v0.6.6
- make sure that `naz` reads exactly the first 4bytes of an smpp header: https://github.com/komuw/naz/pull/153
- if `naz` is unable to read exactly those bytes, it unbinds and closes the connection
- this is so as to ensure that `naz` behaves correctly and does not enter into an inconsistent state.
- make sire that `naz` reads exacly the first 16bytes of the smpp header: https://github.com/komuw/naz/pull/155
- this builds on the [earlier work](https://github.com/komuw/naz/pull/153) but now `naz` takes it a step further and will unbind & close connection if it is unable to read the entire SMPP header
- this is done to prevent inconsistency and also to try and be faithful to the smpp spec.


## **version:** v0.6.5
- Simplify Breach log handler: https://github.com/komuw/naz/pull/152

Expand Down
2 changes: 1 addition & 1 deletion naz/__version__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"__title__": "naz",
"__description__": "Naz is an async SMPP client.",
"__url__": "https://github.com/komuw/naz",
"__version__": "v0.6.5",
"__version__": "v0.6.6",
"__author__": "komuW",
"__author_email__": "[email protected]",
"__license__": "MIT",
Expand Down
58 changes: 36 additions & 22 deletions naz/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,7 @@ def __init__(
self.naz_message_protocol_version = "1"

self.current_session_state = SmppSessionState.CLOSED
self._header_pdu_length = 16

self.drain_duration = drain_duration
self.socket_timeout = socket_timeout
Expand Down Expand Up @@ -848,7 +849,7 @@ async def tranceiver_bind(self, log_id: str = "") -> None:
)

# header
command_length = 16 + len(body) # 16 is for headers
command_length = self._header_pdu_length + len(body) # 16 is for headers
command_id = self.command_ids[smpp_command]
# the status for success see section 5.1.3
command_status = SmppCommandStatus.ESME_ROK.value
Expand Down Expand Up @@ -961,7 +962,7 @@ async def enquire_link(self, TESTING: bool = False) -> typing.Union[None, bytes]
body = b""

# header
command_length = 16 + len(body) # 16 is for headers
command_length = self._header_pdu_length + len(body) # 16 is for headers
command_id = self.command_ids[smpp_command]
command_status = 0x00000000 # not used for `enquire_link`
try:
Expand Down Expand Up @@ -1047,7 +1048,7 @@ async def enquire_link_resp(self, sequence_number: int) -> None:
body = b""

# header
command_length = 16 + len(body) # 16 is for headers
command_length = self._header_pdu_length + len(body) # 16 is for headers
command_id = self.command_ids[smpp_command]
command_status = SmppCommandStatus.ESME_ROK.value
sequence_number = sequence_number
Expand Down Expand Up @@ -1106,7 +1107,7 @@ async def unbind_resp(self, sequence_number: int) -> None:
body = b""

# header
command_length = 16 + len(body) # 16 is for headers
command_length = self._header_pdu_length + len(body) # 16 is for headers
command_id = self.command_ids[smpp_command]
command_status = SmppCommandStatus.ESME_ROK.value
sequence_number = sequence_number
Expand Down Expand Up @@ -1153,7 +1154,7 @@ async def deliver_sm_resp(self, sequence_number: int) -> None:
)

# header
command_length = 16 + len(body) # 16 is for headers
command_length = self._header_pdu_length + len(body) # 16 is for headers
command_id = self.command_ids[smpp_command]
command_status = SmppCommandStatus.ESME_ROK.value
sequence_number = sequence_number
Expand Down Expand Up @@ -1346,7 +1347,7 @@ async def _build_submit_sm_pdu(
)

# header
command_length = 16 + len(body) # 16 is for headers
command_length = self._header_pdu_length + len(body) # 16 is for headers
command_id = self.command_ids[smpp_command]
# the status for success see section 5.1.3
command_status = 0x00000000 # not used for `submit_sm`
Expand Down Expand Up @@ -1842,22 +1843,34 @@ async def receive_data(self, TESTING: bool = False) -> typing.Union[None, bytes]
)
return None

command_length_header_data = b""
header_data = b""
try:
if typing.TYPE_CHECKING:
# make mypy happy; https://github.com/python/mypy/issues/4805
assert isinstance(self.reader, asyncio.streams.StreamReader)

# todo: look at `pause_reading` and `resume_reading` methods
# `client.reader` and `client.writer` should not have timeouts since they are non-blocking
# https://github.com/komuw/naz/issues/116
command_length_header_data = await self.reader.read(4)

# make sure the header data be 4 bytes
while len(command_length_header_data) != 4:
more_bytes = await self.reader.read(4 - len(command_length_header_data))
command_length_header_data = command_length_header_data + more_bytes

header_data = await self.reader.readexactly(self._header_pdu_length)
except asyncio.IncompleteReadError as e:
# see: https://github.com/komuw/naz/issues/135
self._log(
logging.ERROR,
{
"event": "naz.Client.receive_data",
"stage": "end",
"state": "unable to read exactly {0}bytes of smpp header.".format(
self._header_pdu_length
),
"error": str(e),
},
)
header_data == b""
# close connection. it will be automatically reconnected later
await self._unbind_and_disconnect()
if TESTING:
# offer escape hatch for tests to come out of endless loop
return header_data
except (
ConnectionError,
TimeoutError,
Expand All @@ -1877,7 +1890,7 @@ async def receive_data(self, TESTING: bool = False) -> typing.Union[None, bytes]
},
)

if command_length_header_data == b"":
if header_data == b"":
retry_count += 1
poll_read_interval = self._retry_after(retry_count)
self._log(
Expand All @@ -1899,8 +1912,9 @@ async def receive_data(self, TESTING: bool = False) -> typing.Union[None, bytes]
# we didn't fail to read from SMSC
retry_count = 0

total_pdu_length = struct.unpack(">I", command_length_header_data)[0]
MSGLEN = total_pdu_length - 4
# first 4bytes of header are the command_length
total_pdu_length = struct.unpack(">I", header_data[:4])[0]
MSGLEN = total_pdu_length - self._header_pdu_length
chunks = []
bytes_recd = 0
while bytes_recd < MSGLEN:
Expand Down Expand Up @@ -1953,7 +1967,7 @@ async def receive_data(self, TESTING: bool = False) -> typing.Union[None, bytes]
chunks.append(chunk)
bytes_recd = bytes_recd + len(chunk)

full_pdu_data = command_length_header_data + b"".join(chunks)
full_pdu_data = header_data + b"".join(chunks)
self._log(
logging.DEBUG,
{
Expand Down Expand Up @@ -1982,8 +1996,8 @@ async def _parse_response_pdu(self, pdu: bytes) -> None:
{"event": "naz.Client._parse_response_pdu", "stage": "start", "pdu": log_pdu},
)

header_data = pdu[:16]
body_data = pdu[16:]
header_data = pdu[: self._header_pdu_length]
body_data = pdu[self._header_pdu_length :]
command_id_header_data = header_data[4:8]
command_status_header_data = header_data[8:12]
sequence_number_header_data = header_data[12:16]
Expand Down Expand Up @@ -2324,7 +2338,7 @@ async def unbind(self) -> None:
body = b""

# header
command_length = 16 + len(body) # 16 is for headers
command_length = self._header_pdu_length + len(body) # 16 is for headers
command_id = self.command_ids[smpp_command]
command_status = 0x00000000 # not used for `unbind`
try:
Expand Down
100 changes: 86 additions & 14 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,29 +24,82 @@ async def mock_coro(*args, **kwargs):
return mock_coro


class MockStreamWriter:
"""
This is a mock of python's StreamWriter;
https://docs.python.org/3.6/library/asyncio-stream.html#asyncio.StreamWriter
"""

def __init__(self, _is_closing=False):
self.transport = self._create_transport(_is_closing=_is_closing)

async def drain(self):
pass

def close(self):
# when this is called, we set the transport to be in closed/closing state
self.transport = self._create_transport(_is_closing=True)

def write(self, data):
pass

def get_extra_info(self, name, default=None):
# when this is called, we set the transport to be in open state.
# this is because this method is called in `naz.Client.connect`
# so it is the only chance we have of 're-establishing' connection
self.transport = self._create_transport(_is_closing=False)

def _create_transport(self, _is_closing):
class MockTransport:
def __init__(self, _is_closing):
self._is_closing = _is_closing

def set_write_buffer_limits(self, n):
pass

def is_closing(self):
return self._is_closing

return MockTransport(_is_closing=_is_closing)


class MockStreamReader:
"""
This is a mock of python's StreamReader;
https://docs.python.org/3.6/library/asyncio-stream.html#asyncio.StreamReader

We mock the reader having a succesful submit_sm_resp PDU.
For the first read we return the first 4bytes,
the second read, we return the remaining bytes.
"""

def __init__(self, pdu):
self.pdu = pdu

async def read(self, n_index=-1):
if n_index == 0:
return b""
blocks = []
blocks.append(self.pdu)
data = b"".join(blocks)
if n_index == 4:
return data[:n_index]
self.data = b"".join(blocks)

async def read(self, n=-1):
if n == 0:
return b""

if n == -1:
_to_read_data = self.data # read all data
_remaining_data = b""
else:
return data[4:]
_to_read_data = self.data[:n]
_remaining_data = self.data[n:]

self.data = _remaining_data
return _to_read_data

async def readexactly(self, n):
_to_read_data = self.data[:n]
_remaining_data = self.data[n:]

if len(_to_read_data) != n:
# unable to read exactly n bytes
raise asyncio.IncompleteReadError(partial=_to_read_data, expected=n)

self.data = _remaining_data
return _to_read_data


class TestClient(TestCase):
Expand Down Expand Up @@ -472,11 +525,9 @@ def test_receving_data(self):
submit_sm_resp_pdu = (
b"\x00\x00\x00\x12\x80\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x030\x00"
)

# TODO: also create a MockStreamWriter
mock_naz_connect.mock.return_value = (
MockStreamReader(pdu=submit_sm_resp_pdu),
"MockStreamWriter",
MockStreamWriter(),
)

reader, writer = self._run(self.cli.connect())
Expand All @@ -485,6 +536,27 @@ def test_receving_data(self):
received_pdu = self._run(self.cli.receive_data(TESTING=True))
self.assertEqual(received_pdu, submit_sm_resp_pdu)

def test_partial_reads_disconnect(self):
"""
test that if we are unable to read the full 16byte smpp header,
then we should close the connection.
"""
with mock.patch("naz.Client.connect", new=AsyncMock()) as mock_naz_connect, mock.patch(
"naz.Client._unbind_and_disconnect", new=AsyncMock()
) as mock_naz_unbind_and_disconnect:
submit_sm_resp_pdu = b"\x00\x00\x00"
mock_naz_connect.mock.return_value = (
MockStreamReader(pdu=submit_sm_resp_pdu),
MockStreamWriter(),
)

reader, writer = self._run(self.cli.connect())
self.cli.reader = reader
self.cli.writer = writer
received_pdu = self._run(self.cli.receive_data(TESTING=True))
self.assertEqual(received_pdu, b"")
self.assertTrue(mock_naz_unbind_and_disconnect.mock.called)

def test_enquire_link_resp(self):
with mock.patch("naz.q.SimpleOutboundQueue.enqueue", new=AsyncMock()) as mock_naz_enqueue:
sequence_number = 7
Expand Down