Skip to content

Commit

Permalink
Moved timeout to only the initial negotiate instead of a global setti…
Browse files Browse the repository at this point in the history
…ng, set NTLM check to verify username and password is set
  • Loading branch information
jborean93 committed Feb 24, 2018
1 parent 0b0094e commit 991210e
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 19 deletions.
2 changes: 1 addition & 1 deletion smbprotocol/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ def emit(self, record):
logger = logging.getLogger(__name__)
logger.addHandler(NullHandler())

__version__ = '0.0.1.dev3'
__version__ = '0.0.1.dev4'
27 changes: 11 additions & 16 deletions smbprotocol/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,8 +705,7 @@ def __init__(self):

class Connection(object):

def __init__(self, guid, server_name, port=445, require_signing=True,
timeout=60):
def __init__(self, guid, server_name, port=445, require_signing=True):
"""
[MS-SMB2] v53.0 2017-09-15
Expand All @@ -720,15 +719,12 @@ def __init__(self, guid, server_name, port=445, require_signing=True,
:param port: The port to use for the transport, default is 445
:param require_signing: Whether signing is required on SMB messages
sent over this connection
:param timeout: The default connection timeout used when waiting for
a response from the server
"""
log.info("Initialising connection, guid: %s, require_singing: %s, "
"server_name: %s, port: %d"
% (guid, require_signing, server_name, port))
self.server_name = server_name
self.port = port
self.timeout = timeout
self.transport = Tcp(server_name, port)

# Table of Session entries
Expand Down Expand Up @@ -802,7 +798,7 @@ def __init__(self, guid, server_name, port=445, require_signing=True,
# data being read in multiple locations
self.rec_lock = Lock()

def connect(self, dialect=None):
def connect(self, dialect=None, timeout=60):
"""
Will connect to the target server and negotiate the capabilities
with the client. Once setup, the client MUST call the disconnect()
Expand All @@ -813,12 +809,14 @@ def connect(self, dialect=None):
:param dialect: If specified, forces the dialect that is negotiated
with the server, if not set, then the newest dialect supported by
the server is used up to SMB 3.1.1
:param timeout: The timeout in seconds to wait for the initial
negotiation process to complete
"""
log.info("Setting up transport connection")
self.transport.connect()

log.info("Starting negotiation with SMB server")
smb_response = self._send_smb2_negotiate(dialect)
smb_response = self._send_smb2_negotiate(dialect, timeout)
log.info("Negotiated dialect: %s"
% str(smb_response['dialect_revision']))
self.dialect = smb_response['dialect_revision'].get_value()
Expand Down Expand Up @@ -985,13 +983,10 @@ def receive(self, request, wait=True, timeout=None):
:param wait: Wait for the final response in the case of a
STATUS_PENDING response, the pending response is returned in the
case of wait=False
:param timeout: Override the default timeout used to setup the
Connection, will raise an SMBException if the timeout is reached.
:param timeout: Set a timeout used while waiting for a response from
the server
:return: SMB2HeaderResponse of the received message
"""
rec_timeout = self.timeout
if timeout:
rec_timeout = timeout
start_time = time.time()

# check if we have received a response
Expand All @@ -1003,10 +998,10 @@ def receive(self, request, wait=True, timeout=None):
status != NtStatus.STATUS_PENDING):
break
current_time = time.time() - start_time
if current_time > rec_timeout:
if timeout and (current_time > timeout):
error_msg = "Connection timeout of %d seconds exceeded while" \
" waiting for a response from the server" \
% rec_timeout
% timeout
raise smbprotocol.exceptions.SMBException(error_msg)

response = request.response
Expand Down Expand Up @@ -1217,7 +1212,7 @@ def _decrypt(self, message):
dec_message = c.decrypt(nonce, enc_message, message.pack()[20:52])
return dec_message

def _send_smb2_negotiate(self, dialect):
def _send_smb2_negotiate(self, dialect, timeout):
self.salt = os.urandom(32)

if dialect is None:
Expand Down Expand Up @@ -1290,7 +1285,7 @@ def _send_smb2_negotiate(self, dialect):
request = self.send(neg_req)
self.preauth_integrity_hash_value.append(request.message)

response = self.receive(request)
response = self.receive(request, timeout)
log.info("Receiving SMB2 Negotiate response")
log.debug(str(response))
self.preauth_integrity_hash_value.append(response)
Expand Down
7 changes: 7 additions & 0 deletions smbprotocol/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,13 @@ def _smb3kdf(self, ki, label, context):
class NtlmContext(object):

def __init__(self, username, password):
if username is None:
raise SMBAuthenticationError("The username must be set when using "
"NTLM authentication")
if password is None:
raise SMBAuthenticationError("The password must be set when using "
"NTLM authentication")

# try and get the domain part from the username
log.info("Setting up NTLM Security Context for user %s" % username)
try:
Expand Down
37 changes: 35 additions & 2 deletions tests/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

from smbprotocol.connection import Connection, Dialects, SecurityMode
from smbprotocol.exceptions import SMBAuthenticationError, SMBException
from smbprotocol.session import Session, SMB2Logoff, SMB2SessionSetupRequest, \
SMB2SessionSetupResponse
from smbprotocol.session import NtlmContext, Session, SMB2Logoff, \
SMB2SessionSetupRequest, SMB2SessionSetupResponse

from .utils import smb_real

Expand Down Expand Up @@ -104,6 +104,39 @@ def test_parse_message(self):
assert actual['reserved'].get_value() == 0


class TestNtlmContext(object):

def test_no_username_fail(self):
with pytest.raises(SMBException) as exc:
NtlmContext(None, None)
assert str(exc.value) == "The username must be set when using NTLM " \
"authentication"

def test_no_password_fail(self):
with pytest.raises(SMBException) as exc:
NtlmContext("username", None)
assert str(exc.value) == "The password must be set when using NTLM " \
"authentication"

def test_username_without_domain(self):
actual = NtlmContext("username", "password")
assert actual.domain == ""
assert actual.username == "username"
assert actual.password == "password"

def test_username_in_netlogon_form(self):
actual = NtlmContext("DOMAIN\\username", "password")
assert actual.domain == "DOMAIN"
assert actual.username == "username"
assert actual.password == "password"

def test_username_in_upn_form(self):
actual = NtlmContext("[email protected]", "password")
assert actual.domain == ""
assert actual.username == "[email protected]"
assert actual.password == "password"


class TestSession(object):

def test_dialect_2_0_2(self, smb_real):
Expand Down

0 comments on commit 991210e

Please sign in to comment.