diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 00000000..5fee3259 --- /dev/null +++ b/.travis.yml @@ -0,0 +1,39 @@ +sudo: required + +language: python + +python: +- "2.6" +- "2.7" +- "3.4" +- "3.5" +- "3.6" + +services: +- docker + +env: + global: + - SMB_USER: smbuser + - SMB_PASSWORD: smbpassword + - SMB_SERVER: 127.0.0.1 + - SMB_SHARE: share + - SMB_PORT: 445 + +install: +- | + if [[ $TRAVIS_PYTHON_VERSION == '2.6' ]]; then + export SMB_SKIP="True"; + else + docker run -d -p $SMB_PORT:445 -v $(pwd)/build-scripts:/app -w /app -e SMB_USER=$SMB_USER -e SMB_PASSWORD=$SMB_PASSWORD -e SMB_SHARE=$SMB_SHARE centos:7 /bin/bash /app/setup_samba.sh; + fi +- pip install -U pip setuptools +- pip install . +- pip install -r requirements-test.txt +- pip install python-coveralls + +script: +- py.test -v --instafail --pep8 --cov smbprotocol --cov-report term-missing + +after_success: +- coveralls diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 00000000..c180fed7 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,15 @@ +# Changelog + +## 0.0.1 (Unreleased) + +Initial release of smbprotocol, it contains the following features + +* Support for Dialect 2.0.2 to 3.1.1 +* Supports message encryption and signing +* Works with both NTLM and Kerberos auth (latter requiring a non-windows + library) +* Open files, directories and pipes +* Open command with create_contexts to set extra attributes on an open +* Read/Write the files +* Send IOCTL commands +* Sending of multiple messages in one packet (compounding) diff --git a/LICENSE b/LICENSE index fb5ee9f2..d56cbb07 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2017 Jordan Borean +Copyright (c) 2017 Jordan Borean, Red Hat Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 00000000..99bb13cf --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,2 @@ +include LICENSE +include CHANGELOG.md \ No newline at end of file diff --git a/README.md b/README.md index 8654332f..eb36804e 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,191 @@ # smbprotocol -Python SMBv2 and v3 Client +SMBv2 and v3 Client for both Python 2 and 3. -This is in progress and is designed to be a Python library that can interact with SMBv2 and SMBv3 servers using the [MS-SMB2](https://msdn.microsoft.com/en-us/library/cc246482.aspx) protocol. +[![License](https://img.shields.io/badge/license-MIT-blue.svg)](https://github.com/jborean93/smbprotocol/blob/master/LICENSE) +[![Travis Build](https://travis-ci.org/jborean93/smbprotocol.svg)](https://travis-ci.org/jborean93/smbprotocol) +[![AppVeyor Build](https://ci.appveyor.com/api/projects/status/github/jborean93/smbprotocol?svg=true)](https://ci.appveyor.com/project/jborean93/smbprotocol) +[![Coverage](https://coveralls.io/repos/jborean93/smbprotocol/badge.svg)](https://coveralls.io/r/jborean93/smbprotocol) + +SMB is a network file sharing protocol and has numerous iterations over the +years. This library implements the SMBv2 and SMBv3 protocol based on the +[MS-SMB2](https://msdn.microsoft.com/en-us/library/cc246482.aspx) document. + + +## Features + +* Negotiation of the SMB 2.0.2 protocol to SMB 3.1.1 (Windows 10/Server 2016) +* Authentication with both NTLM and Kerberos +* Message signing +* Message encryption (SMB 3.x.x+) +* Connect to a Tree/Share +* Opening of files, pipes and directories +* Set create contexts when opening files +* Read and writing of files and pipes +* Sending IOCTL commands +* Sending of multiple messages in one packet (compounding) + +This is definitely not feature complete as SMB is quite a complex protocol, see +backlog for features that would be nice to have in this library. + + +## Requirements + +* Python 2.6, 2.7, 3.4-3.6 +* For Kerberos auth, the [python-gssapi](https://github.com/pythongssapi/python-gssapi) package (see below) + +The python-gssapi library is required to support Kerberos authentication but +`smbprotocol` requires the GSSAPI GGF extension to support things like +message encryption. To test out if the installed version of python-gsspapi +can be used you can run the python commands in a Python console; + +``` +try: + from gssapi.raw import inquire_sec_context_by_oid + print("python-gssapi extension is available") +except ImportError as exc: + print("python-gssapi extension is not available: %s" % str(exc)) +``` + +If it isn't available, then either a newer version of the system's gssapi +implementation needs to be setup and python-gssapi compiled against that newer +version. + + +## Installation + +To install smbprotocol, simply run + +`pip install smbprotocol` + +This will download the required packages that are used in this package and get +your Python environment ready to go. + + +## Additional Info + +One of the first steps as part of the SMB protocol is to negotiate the dialect +used and other features that are available. Currently smbprotocol supports +the following dialects; + +* `2.0.0`: Added with Server 2008/Windows Vista +* `2.1.0`: Added with Server 2008 R2/Windows 7 +* `3.0.0`: Added with Server 2012/Windows 8 +* `3.0.2`: Added with Server 2012 R2/Windows 8.1 +* `3.1.1`: Added with Server 2016/Windows10 + +Each dialect adds in more features to the protocol where some are minor but +some are major. One major changes is in Dialect 3.x where it added message +encryption. Message encryption is set to True by default and needs to be +overridden when creating a Session object for the older dialects. + +By default, the negotiation process will use the latest dialect that is +supported by the server but this can be overridden if required. When this is +done by the following code + +``` +import uuid + +from smbprotocol.connection import Connection, Dialects + +connection = Connection(uuid.uuid4(), "server", 445) +connection.connect(Dialects.SMB_3_0_2) +``` + +While you shouldn't want to downgrade to an earlier version, this does allow +you to set a minimum dialect version if required. + + +## Examples + +Currently the existing classes expose a very low level interface to the SMB +protocol which can make things quite complex for people starting to use this +package. I do plan on making a high-level interface to make things easier for +users but that's in the backlog. + +For now, the `examples` folder contains some examples of how this package can +be used. + + +## Logging + +This library makes use of the builtin Python logging facilities. Log messages +are logged to the `smbprotocol` named logger as well as `smbprotocol.*` where +`*` is each python script in the `smbprotocol` directory. + +These logs are really useful when debugging issues as they give you a more +step by step snapshot of what it is doing and what may be going wrong. The +debug side will also print out a human readable string of each SMB packet that +is sent out from the client so it can get very verbose. + + +## Testing + +To this module, you need to install some pre-requisites first. This can be done +by running; + +``` +pip install -r requirements-test.txt + +# you can also run tox by installing tox +pip install tox +``` + +From there to run the basic tests run; + +``` +py.test -v --pep8 --cov smbprotocol --cov-report term-missing + +# or with tox 2.7, 2.7, 3.4, 3.5, and 3.6 +tox +``` + +There are extra tests that only run when certain environment variables are set. +To run these tests set the following variables; + +* `SMB_USER`: The username to authenticate with +* `SMB_PASSWORD`: The password to authenticate with +* `SMB_SERVER`: The IP or hostname of the server to authenticate with +* `SMB_PORT`: The port the SMB server is listening on, default is `445` +* `SMB_SHARE`: The name of the share to connect to, a share with this name must exist as well as a share with the name`$SMB_SHARE-encrypted` must also exist that forces encryption + +From here running `tox` or `py.test` with these environment variables set will +activate the integration tests. + +To set up a Windows host that will work with these tests run the following in +PowerShell; + +```powershell +New-Item -Path C:\share -ItemType Directory > $null +New-Item -Path C:\share-encrypted -ItemType Directory > $null +New-SmbShare -Name $env:SMB_SHARE -Path C:\share -EncryptData $false -FullAccess Everyone > $null +New-SmbShare -Name "$($env:SMB_SHARE)-encrypted" -Path C:\share-encrypted -EncryptData $true -FullAccess Everyone > $null +``` + +This requires either Windows 10 or Server 2016 as they support Dialect 3.1.1 +which is required by the tests. + +If you don't have access to a Windows host, you can use Docker to setup a +Samba container and use that as part of the tests. To do so run the following +bash commands; + +```bash +export SMB_USER=smbuser +export SMB_PASSWORD=smbpassword +export SMB_PORT=445 +export SMB_SERVER=127.0.0.1 +export SMB_SHARE=share + +docker run -d -p $SMB_PORT:445 -v $(pwd)/build-scripts:/app -w /app -e SMB_USER=$SMB_USER -e SMB_PASSWORD=$SMB_PASSWORD -e SMB_SHARE=$SMB_SHARE centos:7 /bin/bash /app/setup_samba.sh; +``` + + +## Backlog + +Here is a list of features that I would like to incorporate, PRs are welcome +if you want to implement them yourself; + +* SSPI integration for Windows and Kerberos authentication +* Test and support DFS mounts and not just server shares +* Multiple channel support to speed up large data transfers +* Create an easier API on top of the `raw` SMB calls that currently exist +* Lots and lots more... diff --git a/appveyor.yml b/appveyor.yml new file mode 100644 index 00000000..b2f7384a --- /dev/null +++ b/appveyor.yml @@ -0,0 +1,47 @@ +# Windows Server 2016 +image: Visual Studio 2017 + +environment: + global: + # SDK v7.0 MSVC Express 2008's SetEnv.cmd script will fail if the + # /E:ON and /V:ON options are not enabled in the batch script intepreter + # See: http://stackoverflow.com/a/13751649/163740 + WITH_COMPILER: "cmd /E:ON /V:ON /C .\\build-scripts\\run_with_compiler.cmd" + SMB_SERVER: 127.0.0.1 + SMB_SHARE: share + SMB_PORT: 445 + matrix: + # https://www.appveyor.com/docs/installed-software/#python + - PYTHON: Python27 + - PYTHON: Python27-x64 + - PYTHON: Python34 + - PYTHON: Python34-x64 + - PYTHON: Python35 + - PYTHON: Python35-x64 + - PYTHON: Python36 + - PYTHON: Python36-x64 + +init: +- ps: | + $ErrorActionPreference = "Stop" + # Override default Python version/architecture + $env:Path="C:\$env:PYTHON;C:\$env:PYTHON\Scripts;$env:PATH" + python -c "import platform; print('Python', platform.python_version(), platform.architecture()[0])" + New-Item -Path C:\share -ItemType Directory > $null + New-Item -Path C:\share-encrypted -ItemType Directory > $null + New-SmbShare -Name $env:SMB_SHARE -Path C:\share -EncryptData $false -FullAccess Everyone > $null + New-SmbShare -Name "$($env:SMB_SHARE)-encrypted" -Path C:\share-encrypted -EncryptData $true -FullAccess Everyone > $null + + $env:SMB_USER = $($env:USERNAME) + $env:SMB_PASSWORD = [Microsoft.Win32.Registry]::GetValue("HKEY_LOCAL_MACHINE\Software\Microsoft\Windows NT\CurrentVersion\Winlogon", "DefaultPassword", '') + +install: +- ps: | + pip install -U pip setuptools + pip install . + pip install -r requirements-test.txt + +build: off # Do not run MSBuild, build stuff at install step + +test_script: +- ps: py.test -v --instafail --pep8 --cov smbprotocol --cov-report term-missing diff --git a/build-scripts/run_with_compiler.bat b/build-scripts/run_with_compiler.bat new file mode 100644 index 00000000..3a472bc8 --- /dev/null +++ b/build-scripts/run_with_compiler.bat @@ -0,0 +1,47 @@ +:: To build extensions for 64 bit Python 3, we need to configure environment +:: variables to use the MSVC 2010 C++ compilers from GRMSDKX_EN_DVD.iso of: +:: MS Windows SDK for Windows 7 and .NET Framework 4 (SDK v7.1) +:: +:: To build extensions for 64 bit Python 2, we need to configure environment +:: variables to use the MSVC 2008 C++ compilers from GRMSDKX_EN_DVD.iso of: +:: MS Windows SDK for Windows 7 and .NET Framework 3.5 (SDK v7.0) +:: +:: 32 bit builds do not require specific environment configurations. +:: +:: Note: this script needs to be run with the /E:ON and /V:ON flags for the +:: cmd interpreter, at least for (SDK v7.0) +:: +:: More details at: +:: https://github.com/cython/cython/wiki/64BitCythonExtensionsOnWindows +:: http://stackoverflow.com/a/13751649/163740 +:: +:: Author: Olivier Grisel +:: License: CC0 1.0 Universal: http://creativecommons.org/publicdomain/zero/1.0/ +@ECHO OFF + +SET COMMAND_TO_RUN=%* +SET WIN_SDK_ROOT=C:\Program Files\Microsoft SDKs\Windows + +SET MAJOR_PYTHON_VERSION="%PYTHON_VERSION:~0,1%" +IF %MAJOR_PYTHON_VERSION% == "2" ( + SET WINDOWS_SDK_VERSION="v7.0" +) ELSE IF %MAJOR_PYTHON_VERSION% == "3" ( + SET WINDOWS_SDK_VERSION="v7.1" +) ELSE ( + ECHO Unsupported Python version: "%MAJOR_PYTHON_VERSION%" + EXIT 1 +) + +IF "%PYTHON_ARCH%"=="64" ( + ECHO Configuring Windows SDK %WINDOWS_SDK_VERSION% for Python %MAJOR_PYTHON_VERSION% on a 64 bit architecture + SET DISTUTILS_USE_SDK=1 + SET MSSdk=1 + "%WIN_SDK_ROOT%\%WINDOWS_SDK_VERSION%\Setup\WindowsSdkVer.exe" -q -version:%WINDOWS_SDK_VERSION% + "%WIN_SDK_ROOT%\%WINDOWS_SDK_VERSION%\Bin\SetEnv.cmd" /x64 /release + ECHO Executing: %COMMAND_TO_RUN% + call %COMMAND_TO_RUN% || EXIT 1 +) ELSE ( + ECHO Using default MSVC build environment for 32 bit architecture + ECHO Executing: %COMMAND_TO_RUN% + call %COMMAND_TO_RUN% || EXIT 1 +) diff --git a/build-scripts/setup_samba.sh b/build-scripts/setup_samba.sh new file mode 100755 index 00000000..abf3701b --- /dev/null +++ b/build-scripts/setup_samba.sh @@ -0,0 +1,43 @@ +# install samba on a Centos server +yum install samba -y + +# set basic SMB configuration +cat > /etc/samba/smb.conf << EOL +[global] +workgroup = WORKGROUP +valid users = @smbgroup + +[$SMB_SHARE] +comment = Test Samba Share +path = /srv/samba/$SMB_SHARE +browsable = yes +guest ok = no +read only = no +create mask = 0755 + +[${SMB_SHARE}-encrypted] +command = Test Encrypted Samba Share +path = /srv/samba/${SMB_SHARE}-encrypted +browsable = yes +guest ok = no +read only = no +create mask = 0755 +smb encrypt = required +EOL + +# create smb user +groupadd smbgroup +useradd $SMB_USER -G smbgroup +(echo $SMB_PASSWORD; echo $SMB_PASSWORD) | smbpasswd -s -a $SMB_USER + +# create smb share and configure permissions +mkdir -p /srv/samba/$SMB_SHARE +chmod -R 0755 /srv/samba/$SMB_SHARE +chown -R $SMB_USER:smbgroup /srv/samba/$SMB_SHARE + +mkdir -p /srv/samba/${SMB_SHARE}-encrypted +chmod -R 0755 /srv/samba/${SMB_SHARE}-encrypted +chown -R $SMB_USER:smbgroup /srv/samba/${SMB_SHARE}-encrypted + +# run smb service +/usr/sbin/smbd -F -S < /dev/null diff --git a/examples/directory-management.py b/examples/directory-management.py new file mode 100644 index 00000000..e03cb8ef --- /dev/null +++ b/examples/directory-management.py @@ -0,0 +1,83 @@ +import uuid + +from smbprotocol.connection import Connection +from smbprotocol.session import Session +from smbprotocol.open import CreateDisposition, CreateOptions, \ + DirectoryAccessMask, FileAttributes, FileInformationClass, \ + FilePipePrinterAccessMask, ImpersonationLevel, Open, ShareAccess +from smbprotocol.tree import TreeConnect + +server = "127.0.0.1" +port = 1445 +username = "smbuser" +password = "smbpassword1" +share = r"\\%s\share" % server +dir_name = "directory" + +connection = Connection(uuid.uuid4(), server, port) +connection.connect() + +try: + session = Session(connection, username, password) + session.connect() + tree = TreeConnect(session, share) + tree.connect() + + # ensure directory is created + dir_open = Open(tree, dir_name) + dir_open.open( + ImpersonationLevel.Impersonation, + DirectoryAccessMask.GENERIC_READ | DirectoryAccessMask.GENERIC_WRITE, + FileAttributes.FILE_ATTRIBUTE_DIRECTORY, + ShareAccess.FILE_SHARE_READ | ShareAccess.FILE_SHARE_WRITE, + CreateDisposition.FILE_OPEN_IF, + CreateOptions.FILE_DIRECTORY_FILE + ) + + # create some files in dir and query the contents as part of a compound + # request + directory_file = Open(tree, r"%s\file.txt" % dir_name) + directory_file.open(ImpersonationLevel.Impersonation, + FilePipePrinterAccessMask.MAXIMUM_ALLOWED, + FileAttributes.FILE_ATTRIBUTE_NORMAL, + ShareAccess.FILE_SHARE_READ, + CreateDisposition.FILE_OVERWRITE_IF, + CreateOptions.FILE_NON_DIRECTORY_FILE | + CreateOptions.FILE_DELETE_ON_CLOSE) + + compound_messages = [ + directory_file.write("Hello World".encode('utf-8'), 0, send=False), + dir_open.query_directory("*", + FileInformationClass.FILE_NAMES_INFORMATION, + send=False), + directory_file.close(False, send=False), + dir_open.close(False, send=False) + ] + requests = connection.send_compound([x[0] for x in compound_messages], + session.session_id, + tree.tree_connect_id) + responses = [] + for i, request in enumerate(requests): + response = compound_messages[i][1](request) + responses.append(response) + + dir_files = [] + for dir_file in responses[1]: + dir_files.append(dir_file['file_name'].get_value().decode('utf-16-le')) + + print("Directory '%s\\%s' contains the files: '%s'" + % (share, dir_name, "', '".join(dir_files))) + + # delete a directory (note the dir needs to be empty to delete on close) + dir_open = Open(tree, dir_name) + dir_open.open( + ImpersonationLevel.Impersonation, + DirectoryAccessMask.DELETE, + FileAttributes.FILE_ATTRIBUTE_DIRECTORY, + 0, + CreateDisposition.FILE_OPEN, + CreateOptions.FILE_DIRECTORY_FILE | CreateOptions.FILE_DELETE_ON_CLOSE + ) + dir_open.close(False) +finally: + connection.disconnect(True) diff --git a/examples/file-management.py b/examples/file-management.py new file mode 100644 index 00000000..acb5830a --- /dev/null +++ b/examples/file-management.py @@ -0,0 +1,105 @@ +import uuid + +from smbprotocol.connection import Connection +from smbprotocol.create_contexts import CreateContextName, \ + SMB2CreateContextRequest, SMB2CreateQueryMaximalAccessRequest +from smbprotocol.security_descriptor import AccessAllowedAce, AccessMask, \ + AclPacket, SIDPacket, SMB2CreateSDBuffer +from smbprotocol.session import Session +from smbprotocol.structure import FlagField +from smbprotocol.open import CreateDisposition, CreateOptions, \ + FileAttributes, FilePipePrinterAccessMask, ImpersonationLevel, Open, \ + ShareAccess +from smbprotocol.tree import TreeConnect + +server = "127.0.0.1" +port = 1445 +username = "smbuser" +password = "smbpassword1" +share = r"\\%s\share" % server +file_name = "file-test.txt" + +connection = Connection(uuid.uuid4(), server, port) +connection.connect() + +try: + session = Session(connection, username, password) + session.connect() + tree = TreeConnect(session, share) + tree.connect() + + # ensure file is created, get maximal access, and set everybody read access + max_req = SMB2CreateContextRequest() + max_req['buffer_name'] = \ + CreateContextName.SMB2_CREATE_QUERY_MAXIMAL_ACCESS_REQUEST + max_req['buffer_data'] = SMB2CreateQueryMaximalAccessRequest() + + # create security buffer that sets the ACL for everyone to have read access + everyone_sid = SIDPacket() + everyone_sid.from_string("S-1-1-0") + + ace = AccessAllowedAce() + ace['mask'] = AccessMask.GENERIC_ALL + ace['sid'] = everyone_sid + + acl = AclPacket() + acl['aces'] = [ace] + + sec_desc = SMB2CreateSDBuffer() + sec_desc.set_dacl(acl) + sd_buffer = SMB2CreateContextRequest() + sd_buffer['buffer_name'] = CreateContextName.SMB2_CREATE_SD_BUFFER + sd_buffer['buffer_data'] = sec_desc + + create_contexts = [ + max_req, + sd_buffer + ] + + file_open = Open(tree, file_name) + open_info = file_open.open( + ImpersonationLevel.Impersonation, + FilePipePrinterAccessMask.GENERIC_READ | + FilePipePrinterAccessMask.GENERIC_WRITE, + FileAttributes.FILE_ATTRIBUTE_NORMAL, + ShareAccess.FILE_SHARE_READ | ShareAccess.FILE_SHARE_WRITE, + CreateDisposition.FILE_OVERWRITE_IF, + CreateOptions.FILE_NON_DIRECTORY_FILE, + create_contexts + ) + + # as the raw structure 'maximal_access' is an IntField, we create our own + # flag field, set the value and get the human readble string + max_access = FlagField( + size=4, + flag_type=FilePipePrinterAccessMask, + flag_strict=False + ) + max_access.set_value(open_info[0]['maximal_access'].get_value()) + print("Maximum access mask for file %s\\%s: %s" + % (share, file_name, str(max_access))) + + # write to a file + text = "Hello World, what a nice day to play with SMB" + file_open.write(text.encode('utf-8'), 0) + + # read from a file + file_text = file_open.read(0, 1024) + print("Text of file %s\\%s: %s" + % (share, file_name, file_text.decode('utf-8'))) + file_open.close(False) + + # delete a file + file_open = Open(tree, file_name) + file_open.open( + ImpersonationLevel.Impersonation, + FilePipePrinterAccessMask.DELETE, + FileAttributes.FILE_ATTRIBUTE_NORMAL, + 0, + CreateDisposition.FILE_OPEN, + CreateOptions.FILE_NON_DIRECTORY_FILE | + CreateOptions.FILE_DELETE_ON_CLOSE + ) + file_open.close(False) +finally: + connection.disconnect(True) diff --git a/requirements-test.txt b/requirements-test.txt new file mode 100644 index 00000000..425931ef --- /dev/null +++ b/requirements-test.txt @@ -0,0 +1,4 @@ +pytest<=3.2.5 +pytest-cov +pytest-pep8 +pytest-instafail \ No newline at end of file diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 00000000..b6ebef81 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,5 @@ +[bdist_wheel] +universal = 1 + +[tool:pytest] +pep8ignore = setup.py E501 diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..0c6465a1 --- /dev/null +++ b/setup.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python +# coding: utf-8 + +from setuptools import setup +from smbprotocol import __version__ + +# PyPi supports only reStructuredText, so pandoc should be installed +# before uploading package +try: + import pypandoc + long_description = pypandoc.convert('README.md', 'rst') +except ImportError: + long_description = '' + + +setup( + name='smbprotocol', + version=__version__, + packages=['smbprotocol'], + install_requires=[ + 'cryptography>=2.0', + 'ntlm-auth', + 'pyasn1', + 'six', + ], + extras_require={ + ':python_version<"2.7"': [ + 'ordereddict' + ], + ':sys_platform!="win32"': [ + 'gssapi>=1.4.1' + ] + }, + author='Jordan Borean', + author_email='jborean93@gmail.com', + url='https://github.com/jborean93/smbprotocol', + description='Interact with a server using the SMB 2/3 Protocol', + long_description=long_description, + keywords='smb smb2 smb3 cifs python', + license='MIT', + classifiers=[ + 'Development Status :: 1 - Planning', + 'License :: OSI Approved :: MIT License', + 'Programming Language :: Python :: 2', + 'Programming Language :: Python :: 2.6', + 'Programming Language :: Python :: 2.7', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.4', + 'Programming Language :: Python :: 3.5', + 'Programming Language :: Python :: 3.6', + ], +) diff --git a/smbprotocol/__init__.py b/smbprotocol/__init__.py new file mode 100644 index 00000000..4ee5b271 --- /dev/null +++ b/smbprotocol/__init__.py @@ -0,0 +1,13 @@ +import logging + +try: + from logging import NullHandler +except ImportError: # pragma: no cover + class NullHandler(logging.Handler): + def emit(self, record): + pass + +logger = logging.getLogger(__name__) +logger.addHandler(NullHandler()) + +__version__ = '0.0.1.dev0' diff --git a/smbprotocol/connection.py b/smbprotocol/connection.py new file mode 100644 index 00000000..c76ca8f9 --- /dev/null +++ b/smbprotocol/connection.py @@ -0,0 +1,1343 @@ +from __future__ import division + +import copy +import hashlib +import hmac +import logging +import math +import os +import struct +import sys +from datetime import datetime +from multiprocessing.dummy import Lock + +from cryptography.exceptions import UnsupportedAlgorithm +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import cmac +from cryptography.hazmat.primitives.ciphers import aead, algorithms + +import smbprotocol.exceptions +from smbprotocol.structure import BytesField, DateTimeField, EnumField, \ + FlagField, IntField, ListField, Structure, StructureField, UuidField +from smbprotocol.transport import Tcp + +try: + from collections import OrderedDict +except ImportError: # pragma: no cover + from ordereddict import OrderedDict + +if sys.version[0] == '2': + from Queue import Empty +else: + from queue import Empty + +log = logging.getLogger(__name__) + + +class Commands(object): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.1.2 SMB2 Packet Header - SYNC Command + The command code of an SMB2 packet, it is used in the packet header. + """ + SMB2_NEGOTIATE = 0x0000 + SMB2_SESSION_SETUP = 0x0001 + SMB2_LOGOFF = 0x0002 + SMB2_TREE_CONNECT = 0x0003 + SMB2_TREE_DISCONNECT = 0x0004 + SMB2_CREATE = 0x0005 + SMB2_CLOSE = 0x0006 + SMB2_FLUSH = 0x0007 + SMB2_READ = 0x0008 + SMB2_WRITE = 0x0009 + SMB2_LOCK = 0x000A + SMB2_IOCTL = 0x000B + SMB2_CANCEL = 0x000C + SMB2_ECHO = 0x000D + SMB2_QUERY_DIRECTORY = 0x000E + SMB2_CHANGE_NOTIFY = 0x000F + SMB2_QUERY_INFO = 0x0010 + SMB2_SET_INFO = 0x0011 + SMB2_OPLOCK_BREAK = 0x0012 + + +class Smb2Flags(object): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.1.2 SMB2 Packet Header - SYNC Flags + Indicates various processing rules that need to be done on the SMB2 packet. + """ + SMB2_FLAGS_SERVER_TO_REDIR = 0x00000001 + SMB2_FLAGS_ASYNC_COMMAND = 0x00000002 + SMB2_FLAGS_RELATED_OPERATIONS = 0x00000004 + SMB2_FLAGS_SIGNED = 0x00000008 + SMB2_FLAGS_PRIORITY_MASK = 0x00000070 + SMB2_FLAGS_DFS_OPERATIONS = 0x10000000 + SMB2_FLAGS_REPLAY_OPERATIONS = 0x20000000 + + +class SecurityMode(object): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.3 SMB2 NEGOTIATE Request SecurityMode + Indicates whether SMB signing is enabled or required by the client. + """ + SMB2_NEGOTIATE_SIGNING_ENABLED = 0x0001 + SMB2_NEGOTIATE_SIGNING_REQUIRED = 0x0002 + + +class Capabilities(object): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.3 SMB2 NEGOTIATE Request Capabilities + Used in SMB3.x and above, used to specify the capabilities supported. + """ + SMB2_GLOBAL_CAP_DFS = 0x00000001 + SMB2_GLOBAL_CAP_LEASING = 0x00000002 + SMB2_GLOBAL_CAP_MTU = 0x00000004 + SMB2_GLOBAL_CAP_MULTI_CHANNEL = 0x00000008 + SMB2_GLOBAL_CAP_PERSISTENT_HANDLES = 0x00000010 + SMB2_GLOBAL_CAP_DIRECTORY_LEASING = 0x00000020 + SMB2_GLOBAL_CAP_ENCRYPTION = 0x00000040 + + +class Dialects(object): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.3 SMB2 NEGOTIATE Request Dialects + 16-bit integeres specifying an SMB2 dialect that is supported. 0x02FF is + used in the SMBv1 negotiate request to say that dialects greater than + 2.0.2 is supported. + """ + SMB_2_0_2 = 0x0202 + SMB_2_1_0 = 0x0210 + SMB_3_0_0 = 0x0300 + SMB_3_0_2 = 0x0302 + SMB_3_1_1 = 0x0311 + SMB_2_WILDCARD = 0x02FF + + +class NegotiateContextType(object): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.3.1 SMB2 NEGOTIATE_CONTENT Request ContextType + Specifies the type of context in an SMB2 NEGOTIATE_CONTEXT message. + """ + SMB2_PREAUTH_INTEGRITY_CAPABILITIES = 0x0001 + SMB2_ENCRYPTION_CAPABILITIES = 0x0002 + + +class HashAlgorithms(object): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.3.1.1 SMB2_PREAUTH_INTEGRITY_CAPABILITIES + 16-bit integer IDs that specify the integrity hash algorithm supported + """ + SHA_512 = 0x0001 + + @staticmethod + def get_algorithm(hash): + return { + HashAlgorithms.SHA_512: hashlib.sha512 + }[hash] + + +class Ciphers(object): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.3.1.2 SMB2_ENCRYPTION_CAPABILITIES + 16-bit integer IDs that specify the supported encryption algorithms. + """ + AES_128_CCM = 0x0001 + AES_128_GCM = 0x0002 + + @staticmethod + def get_cipher(cipher): + return { + Ciphers.AES_128_CCM: aead.AESCCM, + Ciphers.AES_128_GCM: aead.AESGCM + }[cipher] + + @staticmethod + def get_supported_ciphers(): + supported_ciphers = [] + try: + aead.AESGCM(b"\x00" * 16) + supported_ciphers.append(Ciphers.AES_128_GCM) + except UnsupportedAlgorithm: # pragma: no cover + pass + try: + aead.AESCCM(b"\x00" * 16) + except UnsupportedAlgorithm: # pragma: no cover + pass + return supported_ciphers + + +class NtStatus(object): + """ + [MS-ERREF] https://msdn.microsoft.com/en-au/library/cc704588.aspx + + 2.3.1 NTSTATUS Values + These values are set in the status field of an SMB2Header response. This is + not an exhaustive list but common values that are returned. + """ + STATUS_SUCCESS = 0x00000000 + STATUS_PENDING = 0x00000103 + STATUS_EA_LIST_INCONSISTENT = 0x80000014 + STATUS_STOPPED_ON_SYMLINK = 0x8000002D + STATUS_INVALID_PARAMETER = 0xC000000D + STATUS_END_OF_FILE = 0xC0000011 + STATUS_MORE_PROCESSING_REQUIRED = 0xC0000016 + STATUS_ACCESS_DENIED = 0xC0000022 + STATUS_BUFFER_TOO_SMALL = 0xC0000023 + STATUS_OBJECT_NAME_NOT_FOUND = 0xC0000034 + STATUS_OBJECT_NAME_COLLISION = 0xC0000035 + STATUS_OBJECT_PATH_INVALID = 0xC0000039 + STATUS_OBJECT_PATH_NOT_FOUND = 0xC000003A + STATUS_OBJECT_PATH_SYNTAX_BAD = 0xC000003B + STATUS_SHARING_VIOLATION = 0xC0000043 + STATUS_EAS_NOT_SUPPORTED = 0xC000004F + STATUS_EA_TOO_LARGE = 0xC0000050 + STATUS_NONEXISTENT_EA_ENTRY = 0xC0000051 + STATUS_NO_EAS_ON_FILE = 0xC0000052 + STATUS_EA_CORRUPT_ERROR = 0xC0000053 + STATUS_LOGON_FAILURE = 0xC000006D + STATUS_PASSWORD_EXPIRED = 0xC0000071 + STATUS_INSUFFICIENT_RESOURCES = 0xC000009A + STATUS_PIPE_BUSY = 0xC00000AE + STATUS_FILE_IS_A_DIRECTORY = 0xC00000BA + STATUS_NOT_SUPPORTED = 0xC00000BB + STATUS_BAD_NETWORK_NAME = 0xC00000CC + STATUS_REQUEST_NOT_ACCEPTED = 0xC00000D0 + STATUS_INTERNAL_ERROR = 0xC00000E5 + STATUS_NOT_A_DIRECTORY = 0xC0000103 + STATUS_CANNOT_DELETE = 0xC0000121 + STATUS_FILE_CLOSED = 0xC0000128 + STATUS_PIPE_BROKEN = 0xC000014B + STATUS_USER_SESSION_DELETED = 0xC0000203 + + +class SMB2HeaderRequest(Structure): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.1.2 SMB2 Packet Header - SYNC + This is the header definition that contains the ChannelSequence/Reserved + instead of the Status field used for a Packet request. + """ + + def __init__(self): + self.fields = OrderedDict([ + ('protocol_id', BytesField( + size=4, + default=b"\xfeSMB", + )), + ('structure_size', IntField( + size=2, + default=64, + )), + ('credit_charge', IntField(size=2)), + ('channel_sequence', IntField(size=2)), + ('reserved', IntField(size=2)), + ('command', EnumField( + size=2, + enum_type=Commands + )), + ('credit_request', IntField(size=2)), + ('flags', FlagField( + size=4, + flag_type=Smb2Flags, + )), + ('next_command', IntField(size=4)), + ('message_id', IntField(size=8)), + ('process_id', IntField(size=4)), + ('tree_id', IntField(size=4)), + ('session_id', IntField(size=8)), + ('signature', BytesField( + size=16, + default=b"\x00" * 16, + )), + ('data', BytesField()) + ]) + super(SMB2HeaderRequest, self).__init__() + + +class SMB2HeaderResponse(Structure): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.1.2 SMB2 Packet Header - SYNC + The header definition for an SMB Response that contains the Status field + instead of the ChannelSequence/Reserved used for a Packet response. + """ + + def __init__(self): + self.fields = OrderedDict([ + ('protocol_id', BytesField( + size=4, + default=b'\xfeSMB', + )), + ('structure_size', IntField( + size=2, + default=64, + )), + ('credit_charge', IntField(size=2)), + ('status', EnumField( + size=4, + enum_type=NtStatus, + enum_strict=False + )), + ('command', EnumField( + size=2, + enum_type=Commands + )), + ('credit_response', IntField(size=2)), + ('flags', FlagField( + size=4, + flag_type=Smb2Flags, + )), + ('next_command', IntField(size=4)), + ('message_id', IntField(size=8)), + ('reserved', IntField(size=4)), + ('tree_id', IntField(size=4)), + ('session_id', IntField(size=8)), + ('signature', BytesField( + size=16, + default=b"\x00" * 16, + )), + ('data', BytesField()), + ]) + super(SMB2HeaderResponse, self).__init__() + + +class SMB2NegotiateRequest(Structure): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.3 SMB2 Negotiate Request + The SMB2 NEGOTIATE Request packet is used by the client to notify the + server what dialects of the SMB2 Protocol the client understands. This is + only used if the client explicitly sets the Dialect to use to a version + less than 3.1.1. Dialect 3.1.1 added support for negotiate_context and + SMB3NegotiateRequest should be used to support that. + """ + COMMAND = Commands.SMB2_NEGOTIATE + + def __init__(self): + self.fields = OrderedDict([ + ('structure_size', IntField( + size=2, + default=36, + )), + ('dialect_count', IntField( + size=2, + default=lambda s: len(s['dialects'].get_value()), + )), + ('security_mode', FlagField( + size=2, + flag_type=SecurityMode + )), + ('reserved', IntField(size=2)), + ('capabilities', FlagField( + size=4, + flag_type=Capabilities, + )), + ('client_guid', UuidField()), + ('client_start_time', IntField(size=8)), + ('dialects', ListField( + size=lambda s: s['dialect_count'].get_value() * 2, + list_count=lambda s: s['dialect_count'].get_value(), + list_type=EnumField(size=2, enum_type=Dialects), + )), + ]) + + super(SMB2NegotiateRequest, self).__init__() + + +class SMB3NegotiateRequest(Structure): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.3 SMB2 Negotiate Request + Like SMB2NegotiateRequest but with support for setting a list of + Negotiate Context values. This is used by default and is for Dialects 3.1.1 + or greater. + """ + COMMAND = Commands.SMB2_NEGOTIATE + + def __init__(self): + self.fields = OrderedDict([ + ('structure_size', IntField( + size=2, + default=36, + )), + ('dialect_count', IntField( + size=2, + default=lambda s: len(s['dialects'].get_value()), + )), + ('security_mode', FlagField( + size=2, + flag_type=SecurityMode, + )), + ('reserved', IntField(size=2)), + ('capabilities', FlagField( + size=4, + flag_type=Capabilities, + )), + ('client_guid', UuidField()), + ('negotiate_context_offset', IntField( + size=4, + default=lambda s: self._negotiate_context_offset_value(s), + )), + ('negotiate_context_count', IntField( + size=2, + default=lambda s: len(s['negotiate_context_list'].get_value()), + )), + ('reserved2', IntField(size=2)), + ('dialects', ListField( + size=lambda s: s['dialect_count'].get_value() * 2, + list_count=lambda s: s['dialect_count'].get_value(), + list_type=EnumField(size=2, enum_type=Dialects), + )), + ('padding', BytesField( + size=lambda s: self._padding_size(s), + default=lambda s: b"\x00" * self._padding_size(s), + )), + ('negotiate_context_list', ListField( + list_count=lambda s: s['negotiate_context_count'].get_value(), + unpack_func=lambda s, d: self._negotiate_context_list(s, d), + )), + ]) + super(SMB3NegotiateRequest, self).__init__() + + def _negotiate_context_offset_value(self, structure): + # The offset from the beginning of the SMB2 header to the first, 8-byte + # aligned, negotiate context + header_size = 64 + negotiate_size = structure['structure_size'].get_value() + dialect_size = len(structure['dialects']) + padding_size = self._padding_size(structure) + return header_size + negotiate_size + dialect_size + padding_size + + def _padding_size(self, structure): + # Padding between the end of the buffer value and the first Negotiate + # context value so that the first value is 8-byte aligned. Padding is + # 4 is there are no dialects specified + mod = (structure['dialect_count'].get_value() * 2) % 8 + return 0 if mod == 0 else mod + + def _negotiate_context_list(self, structure, data): + context_count = structure['negotiate_context_count'].get_value() + context_list = [] + for idx in range(0, context_count): + field, data = self._parse_negotiate_context_entry(data, idx) + context_list.append(field) + + return context_list + + def _parse_negotiate_context_entry(self, data, idx): + data_length = struct.unpack("= Dialects.SMB_2_1_0: + self.supports_file_leasing = \ + capabilities.has_flag(Capabilities.SMB2_GLOBAL_CAP_LEASING) + self.supports_multi_credit = \ + capabilities.has_flag(Capabilities.SMB2_GLOBAL_CAP_MTU) + + # SMB 3.x + if self.dialect >= Dialects.SMB_3_0_0: + self.supports_directory_leasing = capabilities.has_flag( + Capabilities.SMB2_GLOBAL_CAP_DIRECTORY_LEASING) + self.supports_multi_channel = capabilities.has_flag( + Capabilities.SMB2_GLOBAL_CAP_MULTI_CHANNEL) + + # TODO: SMB2_GLOBAL_CAP_PERSISTENT_HANDLES + self.supports_persistent_handles = False + self.supports_encryption = capabilities.has_flag( + Capabilities.SMB2_GLOBAL_CAP_ENCRYPTION) \ + and self.dialect < Dialects.SMB_3_1_1 + self.server_capabilities = capabilities + self.server_security_mode = \ + smb_response['security_mode'].get_value() + + # TODO: Check/add server to server_list in Client Page 203 + + # SMB 3.1 + if self.dialect >= Dialects.SMB_3_1_1: + for context in smb_response['negotiate_context_list']: + if context['context_type'].get_value() == \ + NegotiateContextType.SMB2_ENCRYPTION_CAPABILITIES: + cipher_id = context['data']['ciphers'][0] + self.cipher_id = Ciphers.get_cipher(cipher_id) + self.supports_encryption = self.cipher_id != 0 + else: + hash_id = context['data']['hash_algorithms'][0] + self.preauth_integrity_hash_id = \ + HashAlgorithms.get_algorithm(hash_id) + + def disconnect(self, close=True): + """ + Closes the connection as well as logs off any of the + Disconnects the TCP connection and shuts down the socket listener + running in a thread. + + :param close: Will close all sessions in the connection as well as the + tree connections of each session. + """ + if close: + for session in list(self.session_table.values()): + session.disconnect(True) + + log.info("Disconnecting transport connection") + self.transport.disconnect() + + def send(self, message, sid=None, tid=None, credit_request=None): + """ + Will send a message to the server that is passed in. The final + unencrypted header is returned to the function that called this. + + :param message: An SMB message structure to send + :param sid: A session_id that the message is sent for + :param tid: A tree_id object that the message is sent for + :param credit_request: Specifies extra credits to be requested with the + SMB header + :return: Request of the message that was sent + """ + header = self._generate_packet_header(message, sid, tid, + credit_request) + + # get the actual Session and TreeConnect object instead of the IDs + session = self.session_table.get(sid, None) if sid else None + + tree = None + if tid and session: + if tid not in session.tree_connect_table.keys(): + error_msg = "Cannot find Tree with the ID %d in the session " \ + "tree table" % tid + raise smbprotocol.exceptions.SMBException(error_msg) + tree = session.tree_connect_table[tid] + + if session and session.signing_required and session.signing_key: + self._sign(header, session) + request = Request(header) + self.outstanding_requests[header['message_id'].get_value()] = request + + send_data = header.pack() + if (session and session.encrypt_data) or (tree and tree.encrypt_data): + send_data = self._encrypt(send_data, session) + + self.transport.send(send_data) + + return request + + def send_compound(self, messages, sid, tid): + """ + Sends multiple messages within 1 TCP request, will fail if the size + of the total length exceeds the maximum of the transport max. + + :param messages: A list of messages to send to the server + :param sid: The session_id that the request is sent for + :param tid: A tree_id object that the message is sent for + :return: List for each request that was sent, each entry in + the list is in the same order of the message list that was passed + in + """ + send_data = b"" + session = self.session_table[sid] + tree = session.tree_connect_table[tid] + requests = [] + + total_requests = len(messages) + for i, message in enumerate(messages): + if i == total_requests - 1: + next_command = 0 + padding = b"" + else: + msg_length = 64 + len(message) + + # each compound message must start at the 8-byte boundary + mod = msg_length % 8 + padding_length = 8 - mod if mod > 0 else 0 + padding = b"\x00" * padding_length + next_command = msg_length + padding_length + + header = self._generate_packet_header(message, sid, tid, None) + header['next_command'] = next_command + if session.signing_required and session.signing_key: + self._sign(header, session, padding=padding) + send_data += header.pack() + padding + + request = Request(header) + requests.append(request) + self.outstanding_requests[header['message_id'].get_value()] = \ + request + + if session.encrypt_data or tree.encrypt_data: + send_data = self._encrypt(send_data, session) + self.transport.send(send_data) + + return requests + + def receive(self, request): + """ + Polls the message buffer of the TCP connection and waits until a valid + message is received based on the message_id passed in. + + :param request: The Request object to wait get the response for + :return: SMB2HeaderResponse of the received message + """ + # check if we have received a response + while not request.response: + self._flush_message_buffer() + + response = request.response + status = response['status'].get_value() + + if status == NtStatus.STATUS_PENDING: + request.response = None + + if status != NtStatus.STATUS_SUCCESS: + raise smbprotocol.exceptions.SMBResponseException(response, status) + + # now we have a retrieval request for the response, we can delete the + # request from the outstanding requests + message_id = request.message['message_id'].get_value() + del self.outstanding_requests[message_id] + + return response + + def _generate_packet_header(self, message, session_id, tree_id, + credit_request): + # when run in a thread or subprocess, getting the message id and + # adjusting the sequence window is important so we acquire a lock to + # ensure only one is run at a point in time + self.lock.acquire() + sequence_window_low = self.sequence_window['low'] + sequence_window_high = self.sequence_window['high'] + credit_charge = self._calculate_credit_charge(message) + credits_available = sequence_window_high - sequence_window_low + if credit_charge > credits_available: + error_msg = "Request requires %d credits but only %d credits " \ + "are available" \ + % (credit_charge, credits_available) + raise smbprotocol.exceptions.SMBException(error_msg) + + message_id = sequence_window_low + self.sequence_window['low'] += \ + credit_charge if credit_charge > 0 else 1 + self.lock.release() + + header = SMB2HeaderRequest() + header['credit_charge'] = credit_charge + header['command'] = message.COMMAND + header['credit_request'] = \ + credit_request if credit_request else credit_charge + header['message_id'] = message_id + header['process_id'] = os.getpid() + header['tree_id'] = tree_id if tree_id else 0 + header['session_id'] = session_id if session_id else 0 + + # we log before adding the data to avoid polluting the logs with + # too much info + log.info("Created SMB Packet Header for %s request" + % str(header['command'])) + log.debug(str(header)) + + header['data'] = message + + return header + + def _flush_message_buffer(self): + """ + Loops through the transport message_buffer until there are no messages + left in the queue. Each response is assigned to the Request object + based on the message_id which are then available in + self.outstanding_requests + """ + while True: + try: + message_bytes = self.transport.message_buffer.get(block=False) + except Empty: + # raises Empty if wait=False and there are no messages, in this + # case we have nothing to parse and so break from the loop + break + + # check if the message is encrypted and decrypt if necessary + if message_bytes[:4] == b"\xfdSMB": + message = SMB2TransformHeader() + message.unpack(message_bytes) + message_bytes = self._decrypt(message) + + # now retrieve message(s) from response + is_last = False + while not is_last: + next_command = struct.unpack(" 0 else 1 + + request.response = message + self.outstanding_requests[message_id] = request + + message_bytes = message_bytes[header_length:] + is_last = next_command == 0 + + def _sign(self, message, session, padding=None): + message['flags'].set_flag(Smb2Flags.SMB2_FLAGS_SIGNED) + signature = self._generate_signature(message, session, padding) + message['signature'] = signature + + def _verify(self, message, verify_session=False): + if message['message_id'].get_value() == 0xFFFFFFFFFFFFFFFF: + return + elif not message['flags'].has_flag(Smb2Flags.SMB2_FLAGS_SIGNED): + return + elif message['command'].get_value() == Commands.SMB2_SESSION_SETUP \ + and not verify_session: + return + + session_id = message['session_id'].get_value() + session = self.session_table.get(session_id, None) + if session is None: + error_msg = "Failed to find session %d for message verification" \ + % session_id + raise smbprotocol.exceptions.SMBException(error_msg) + expected = self._generate_signature(message, session) + actual = message['signature'].get_value() + if actual != expected: + error_msg = "Server message signature could not be verified: " \ + "%s != %s" % (actual, expected) + raise smbprotocol.exceptions.SMBException(error_msg) + + def _generate_signature(self, message, session, padding=None): + msg = copy.deepcopy(message) + msg['signature'] = b"\x00" * 16 + msg_data = msg.pack() + (padding if padding else b"") + + if self.dialect >= Dialects.SMB_3_0_0: + # TODO: work out when to get channel.signing_key + signing_key = session.signing_key + + c = cmac.CMAC(algorithms.AES(signing_key), + backend=default_backend()) + c.update(msg_data) + signature = c.finalize() + else: + signing_key = session.signing_key + hmac_algo = hmac.new(signing_key, msg=msg_data, + digestmod=hashlib.sha256) + signature = hmac_algo.digest()[:16] + + return signature + + def _encrypt(self, data, session): + header = SMB2TransformHeader() + header['original_message_size'] = len(data) + header['session_id'] = session.session_id + + encryption_key = session.encryption_key + if self.dialect >= Dialects.SMB_3_1_1: + cipher = self.cipher_id + else: + cipher = Ciphers.get_cipher(Ciphers.AES_128_CCM) + if cipher == aead.AESGCM: + nonce = os.urandom(12) + header['nonce'] = nonce + (b"\x00" * 4) + else: + nonce = os.urandom(11) + header['nonce'] = nonce + (b"\x00" * 5) + + cipher_text = cipher(encryption_key).encrypt(nonce, data, + header.pack()[20:]) + signature = cipher_text[-16:] + enc_message = cipher_text[:-16] + + header['signature'] = signature + header['data'] = enc_message + + return header + + def _decrypt(self, message): + if message['flags'].get_value() != 0x0001: + error_msg = "Expecting flag of 0x0001 but got %s in the SMB " \ + "Transform Header Response"\ + % format(message['flags'].get_value(), 'x') + raise smbprotocol.exceptions.SMBException(error_msg) + + session_id = message['session_id'].get_value() + session = self.session_table.get(session_id, None) + if session is None: + error_msg = "Failed to find valid session %s for message " \ + "decryption" % session_id + raise smbprotocol.exceptions.SMBException(error_msg) + + if self.dialect >= Dialects.SMB_3_1_1: + cipher = self.cipher_id + else: + cipher = Ciphers.get_cipher(Ciphers.AES_128_CCM) + + nonce_length = 12 if cipher == aead.AESGCM else 11 + nonce = message['nonce'].get_value()[:nonce_length] + + signature = message['signature'].get_value() + enc_message = message['data'].get_value() + signature + + c = cipher(session.decryption_key) + dec_message = c.decrypt(nonce, enc_message, message.pack()[20:52]) + return dec_message + + def _send_smb2_negotiate(self, dialect): + self.salt = os.urandom(32) + + if dialect is None: + neg_req = SMB3NegotiateRequest() + self.negotiated_dialects = [ + Dialects.SMB_2_0_2, + Dialects.SMB_2_1_0, + Dialects.SMB_3_0_0, + Dialects.SMB_3_0_2, + Dialects.SMB_3_1_1 + ] + highest_dialect = Dialects.SMB_3_1_1 + else: + if dialect >= Dialects.SMB_3_1_1: + neg_req = SMB3NegotiateRequest() + else: + neg_req = SMB2NegotiateRequest() + self.negotiated_dialects = [ + dialect + ] + highest_dialect = dialect + neg_req['dialects'] = self.negotiated_dialects + log.info("Negotiating with SMB2 protocol with highest client dialect " + "of: %s" % [dialect for dialect, v in vars(Dialects).items() + if v == highest_dialect][0]) + + neg_req['security_mode'] = self.client_security_mode + + if highest_dialect >= Dialects.SMB_2_1_0: + log.debug("Adding client guid %s to negotiate request" + % self.client_guid) + neg_req['client_guid'] = self.client_guid + + if highest_dialect >= Dialects.SMB_3_0_0: + log.debug("Adding client capabilities %d to negotiate request" + % self.client_capabilities) + neg_req['capabilities'] = self.client_capabilities + + if highest_dialect >= Dialects.SMB_3_1_1: + int_cap = SMB2NegotiateContextRequest() + int_cap['context_type'] = \ + NegotiateContextType.SMB2_PREAUTH_INTEGRITY_CAPABILITIES + int_cap['data'] = SMB2PreauthIntegrityCapabilities() + int_cap['data']['hash_algorithms'] = [ + HashAlgorithms.SHA_512 + ] + int_cap['data']['salt'] = self.salt + log.debug("Adding preauth integrity capabilities of hash SHA512 " + "and salt %s to negotiate request" % self.salt) + + enc_cap = SMB2NegotiateContextRequest() + enc_cap['context_type'] = \ + NegotiateContextType.SMB2_ENCRYPTION_CAPABILITIES + enc_cap['data'] = SMB2EncryptionCapabilities() + supported_ciphers = Ciphers.get_supported_ciphers() + enc_cap['data']['ciphers'] = supported_ciphers + # remove extra padding for last list entry + enc_cap['padding'].size = 0 + enc_cap['padding'] = b"" + log.debug("Adding encryption capabilities of AES128 GCM and " + "AES128 CCM to negotiate request") + + neg_req['negotiate_context_list'] = [ + int_cap, + enc_cap + ] + + log.info("Sending SMB2 Negotiate message") + log.debug(str(neg_req)) + request = self.send(neg_req) + self.preauth_integrity_hash_value.append(request.message) + + response = self.receive(request) + log.info("Receiving SMB2 Negotiate response") + log.debug(str(response)) + self.preauth_integrity_hash_value.append(response) + + smb_response = SMB2NegotiateResponse() + smb_response.unpack(response['data'].get_value()) + + return smb_response + + def _calculate_credit_charge(self, message): + """ + Calculates the credit charge for a request based on the command. If + connection.supports_multi_credit is not True then the credit charge + isn't valid so it returns 0. + + The credit charge is the number of credits that are required for + sending/receiving data over 64 kilobytes, in the existing messages only + the Read, Write, Query Directory or IOCTL commands will end in this + scenario and each require their own calculation to get the proper + value. The generic formula for calculating the credit charge is + + https://msdn.microsoft.com/en-us/library/dn529312.aspx + (max(SendPayloadSize, Expected ResponsePayloadSize) - 1) / 65536 + 1 + + :param message: The message being sent + :return: The credit charge to set on the header + """ + credit_size = 65536 + + if not self.supports_multi_credit: + credit_charge = 0 + elif message.COMMAND == Commands.SMB2_READ: + max_size = message['length'].get_value() + \ + message['read_channel_info_length'].get_value() - 1 + credit_charge = math.ceil(max_size / credit_size) + elif message.COMMAND == Commands.SMB2_WRITE: + max_size = message['length'].get_value() + \ + message['write_channel_info_length'].get_value() - 1 + credit_charge = math.ceil(max_size / credit_size) + elif message.COMMAND == Commands.SMB2_IOCTL: + max_in_size = len(message['buffer']) + max_out_size = message['max_output_response'].get_value() + max_size = max(max_in_size, max_out_size) - 1 + credit_charge = math.ceil(max_size / credit_size) + elif message.COMMAND == Commands.SMB2_QUERY_DIRECTORY: + max_in_size = len(message['buffer']) + max_out_size = message['output_buffer_length'].get_value() + max_size = max(max_in_size, max_out_size) - 1 + credit_charge = math.ceil(max_size / credit_size) + else: + credit_charge = 1 + + # python 2 returns a float where we need an integer + return int(credit_charge) + + +class Request(object): + + def __init__(self, message): + """ + [MS-SMB2] v53.0 2017-09-15 + + 3.2.1.7 Per Pending Request + For each request that was sent to the server and is await a response + :param message: The message to be sent in the request + """ + self.cancel_id = os.urandom(8) + self.async_id = os.urandom(8) + self.message = message + self.timestamp = datetime.now() + + # not in SMB spec + # Used to contain the corresponding response from the server as the + # receiving in done in parallel + self.response = None diff --git a/smbprotocol/create_contexts.py b/smbprotocol/create_contexts.py new file mode 100644 index 00000000..1594f3cc --- /dev/null +++ b/smbprotocol/create_contexts.py @@ -0,0 +1,830 @@ +import smbprotocol.connection +import smbprotocol.open +from smbprotocol.structure import BoolField, BytesField, DateTimeField, \ + EnumField, FlagField, IntField, Structure, UuidField + +try: + from collections import OrderedDict +except ImportError: # pragma: no cover + from ordereddict import OrderedDict + + +class CreateContextName(object): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.13.2 SMB2_CREATE_CONTEXT Request Values + Valid names for the name to set on a SMB2_CREATE_CONTEXT Request entry + """ + SMB2_CREATE_EA_BUFFER = b"\x45\x78\x74\x41" + + # note: the structures for this are located in security_descriptor.py + SMB2_CREATE_SD_BUFFER = b"\x53\x65\x63\x44" + SMB2_CREATE_DURABLE_HANDLE_REQUEST = b"\x44\x48\x6e\x51" + SMB2_CREATE_DURABLE_HANDLE_RECONNECT = b"\x44\x48\x6e\x43" + SMB2_CREATE_ALLOCATION_SIZE = b"\x41\x6c\x53\x69" + SMB2_CREATE_QUERY_MAXIMAL_ACCESS_REQUEST = b"\x4d\x78\x41\x63" + SMB2_CREATE_TIMEWARP_TOKEN = b"\x54\x57\x72\x70" + SMB2_CREATE_QUERY_ON_DISK_ID = b"\x51\x46\x69\x64" + SMB2_CREATE_REQUEST_LEASE = b"\x52\x71\x4c\x73" + SMB2_CREATE_REQUEST_LEASE_V2 = b"\x52\x71\x4c\x73" + SMB2_CREATE_DURABLE_HANDLE_REQUEST_V2 = b"\x44\x48\x32\x51" + SMB2_CREATE_DURABLE_HANDLE_RECONNECT_V2 = b"\x44\x48\x32\x43" + SMB2_CREATE_APP_INSTANCE_ID = b"\x45\xBC\xA6\x6A\xEF\xA7\xF7\x4A" \ + b"\x90\x08\xFA\x46\x2E\x14\x4D\x74" + SMB2_CREATE_APP_INSTANCE_VERSION = b"\xB9\x82\xD0\xB7\x3B\x56\x07\x4F" \ + b"\xA0\x7B\x52\x4A\x81\x16\xA0\x10" + SVHDX_OPEN_DEVICE_CONTEXT = b"\x9C\xCB\xCF\x9E\x04\xC1\xE6\x43" \ + b"\x98\x0E\x15\x8D\xA1\xF6\xEC\x83" + + @staticmethod + def get_response_structure(name): + """ + Returns the response structure for a know list of create context + responses. + + :param name: The constant value above + :return: The response structure or None if unknown + """ + return { + CreateContextName.SMB2_CREATE_DURABLE_HANDLE_REQUEST: + SMB2CreateDurableHandleResponse(), + CreateContextName.SMB2_CREATE_DURABLE_HANDLE_RECONNECT: + SMB2CreateDurableHandleReconnect(), + CreateContextName.SMB2_CREATE_QUERY_MAXIMAL_ACCESS_REQUEST: + SMB2CreateQueryMaximalAccessResponse(), + CreateContextName.SMB2_CREATE_REQUEST_LEASE: + SMB2CreateResponseLease(), + CreateContextName.SMB2_CREATE_QUERY_ON_DISK_ID: + SMB2CreateQueryOnDiskIDResponse(), + CreateContextName.SMB2_CREATE_REQUEST_LEASE_V2: + SMB2CreateResponseLeaseV2(), + CreateContextName.SMB2_CREATE_DURABLE_HANDLE_REQUEST_V2: + SMB2CreateDurableHandleResponseV2(), + CreateContextName.SMB2_CREATE_DURABLE_HANDLE_RECONNECT_V2: + SMB2CreateDurableHandleReconnectV2, + CreateContextName.SMB2_CREATE_APP_INSTANCE_ID: + SMB2CreateAppInstanceId(), + CreateContextName.SMB2_CREATE_APP_INSTANCE_VERSION: + SMB2CreateAppInstanceVersion() + + }.get(name, None) + + +class EAFlags(object): + """ + [MS-FSCC] + + 2.4.15 FileFullEaInformation Flags + Specifies the flag used when setting extended attributes. + """ + NONE = 0x0000000 + FILE_NEED_EA = 0x00000080 + + +class LeaseState(object): + """ + [MS-SMB2] + + 2.2.13.2.8 SMB2_CREATE_REQUEST_LEASE LeaseState + The requested lease state, field is constructed with a combination of the + following values. + """ + SMB2_LEASE_NONE = 0x00 + SMB2_LEASE_READ_CACHING = 0x01 + SMB2_LEASE_HANDLE_CACHING = 0x02 + SMB2_LEASE_WRITE_CACHING = 0x04 + + +class LeaseRequestFlags(object): + """ + [MS-SMB2] + + 2.2.13.2.10 SMB2_CREATE_REQUEST_LEASE_V2 + The flags to use on an SMB2CreateRequestLeaseV2 packet. + """ + SMB2_LEASE_FLAG_PARENT_LEASE_KEY_SET = 0x00000004 + + +class LeaseResponseFlags(object): + """ + [MS-SMB2] + + 2.2.14.2.10 SMB2_CREATE_RESPONSE_LEASE + """ + SMB2_LEASE_FLAG_BREAK_IN_PROGRESS = 0x00000002 + SMB2_LEASE_FLAG_PARENT_LEASE_KEY_SET = 0x00000004 # V2 Response + + +class DurableHandleFlags(object): + """ + [MS-SMB2] + + 2.2.13.2.11 SMB2_CREATE_DURABLE_HANDLE_REQUEST_V2 + Flags used on an SMB2CreateDurableHandleRequestV2 packet. + """ + SMB2_DHANDLE_FLAG_PERSISTENT = 0x00000002 + + +class SVHDXOriginatorFlags(object): + """ + [MS-RSVD] 2.2.4.12 SVHDX_OPEN_DEVICE_CONTEXT OriginatorFlags + Used to indicate which component has originated or issued the operations. + """ + SVHDX_ORIGINATOR_PVHDPARSER = 0x00000001 + SVHDX_ORIGINATOR_VHDMP = 0x00000004 + + +class SMB2CreateContextRequest(Structure): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.13.2 SMB2_CREATE_CONTEXT Request Values + Structure used in the SMB2 CREATE Request and SMB2 CREATE Response to + encode additional flags and attributes + """ + + def __init__(self): + self.fields = OrderedDict([ + ('next', IntField(size=4)), + ('name_offset', IntField( + size=2, + default=16 + )), + ('name_length', IntField( + size=2, + default=lambda s: len(s['buffer_name']) + )), + ('reserved', IntField(size=2)), + ('data_offset', IntField( + size=2, + default=lambda s: self._buffer_data_offset(s) + )), + ('data_length', IntField( + size=4, + default=lambda s: len(s['buffer_data']) + )), + ('buffer_name', BytesField( + size=lambda s: s['name_length'].get_value() + )), + ('padding', BytesField( + size=lambda s: self._padding_size(s), + default=lambda s: b"\x00" * self._padding_size(s) + )), + ('buffer_data', BytesField( + size=lambda s: s['data_length'].get_value() + )), + # not actually a field but each list entry must start at the 8 byte + # alignment + ('padding2', BytesField( + size=lambda s: self._padding2_size(s), + default=lambda s: b"\x00" * self._padding2_size(s) + )) + ]) + super(SMB2CreateContextRequest, self).__init__() + + def _buffer_data_offset(self, structure): + if structure['data_length'].get_value() == 0: + return 0 + else: + return structure['name_offset'].get_value() + \ + len(structure['buffer_name']) + len(structure['padding']) + + def _padding_size(self, structure): + if structure['data_length'].get_value() == 0: + return 0 + + buffer_name_len = structure['name_length'].get_value() + mod = buffer_name_len % 8 + return mod if mod == 0 else 8 - mod + + def _padding2_size(self, structure): + if structure['next'].get_value() == 0: + return 0 + + data_length = len(structure['buffer_name']) + \ + len(structure['padding']) + len(structure['buffer_data']) + mod = data_length % 8 + return mod if mod == 0 else 8 - mod + + def get_context_data(self): + """ + Get the buffer_data value of a context response and try to convert it + to the relevant structure based on the buffer_name used. If it is an + unknown structure then the raw bytes are returned. + + :return: relevant Structure of buffer_data or bytes if unknown name + """ + buffer_name = self['buffer_name'].get_value() + structure = CreateContextName.get_response_structure(buffer_name) + if structure: + structure.unpack(self['buffer_data'].get_value()) + return structure + else: + # unknown structure, just return the raw bytes + return self['buffer_data'].get_value() + + @staticmethod + def pack_multiple(messages): + """ + Converts a list of SMB2CreateContextRequest structures and packs them + as a bytes object used when setting to the SMB2CreateRequest + buffer_contexts field. This should be used as it would calculate the + correct next field value for each context entry. + + :param messages: List of SMB2CreateContextRequest structures + :return: bytes object that is set on the SMB2CreateRequest + buffer_contexts field. + """ + data = b"" + msg_count = len(messages) + for i, msg in enumerate(messages): + if i == msg_count - 1: + msg['next'] = 0 + else: + # because the end padding2 val won't be populated if the entry + # offset is 0, we set to 1 so the len calc is correct + msg['next'] = 1 + msg['next'] = len(msg) + + data += msg.pack() + return data + + +class SMB2CreateEABuffer(Structure): + """ + [MS-SMB2] 2.2.13.2.1 SMB2_CREATE_EA_BUFFER + [MS-FSCC] 2.4.15 FileFullEaInformation + + Used to apply extended attributes as part of creating a new file. + """ + + def __init__(self): + self.fields = OrderedDict([ + # 0 if no more entries, otherwise offset after ea_value + ('next_entry_offset', IntField(size=4)), + ('flags', FlagField( + size=1, + flag_type=EAFlags + )), + ('ea_name_length', IntField( + size=1, + default=lambda s: len(s['ea_name']) - 1 # minus \x00 + )), + ('ea_value_length', IntField( + size=2, + default=lambda s: len(s['ea_value']) + )), + # ea_name is ASCII byte encoded and needs a null terminator '\x00' + ('ea_name', BytesField( + size=lambda s: s['ea_name_length'].get_value() + 1 + )), + ('ea_value', BytesField( + size=lambda s: s['ea_value_length'].get_value() + )), + # not actually a field but each list entry must start at the 4 byte + # alignment + ('padding', BytesField( + size=lambda s: self._padding_size(s), + default=lambda s: b"\x00" * self._padding_size(s) + )) + ]) + super(SMB2CreateEABuffer, self).__init__() + + def _padding_size(self, structure): + if structure['next_entry_offset'].get_value() == 0: + return 0 + + data_length = len(structure['ea_name']) + len(structure['ea_value']) + mod = data_length % 4 + return mod if mod == 0 else 4 - mod + + @staticmethod + def pack_multiple(messages): + """ + Converts a list of SMB2CreateEABuffer structures and packs them as a + bytes object used when setting to the SMB2CreateContextRequest + buffer_data field. This should be used as it would calculate the + correct next_entry_offset field value for each buffer entry. + + :param messages: List of SMB2CreateEABuffer structures + :return: bytes object that is set on the SMB2CreateContextRequest + buffer_data field. + """ + data = b"" + msg_count = len(messages) + for i, msg in enumerate(messages): + if i == msg_count - 1: + msg['next_entry_offset'] = 0 + else: + # because the end padding val won't be populated if the entry + # offset is 0, we set to 1 so the len calc is correct + msg['next_entry_offset'] = 1 + msg['next_entry_offset'] = len(msg) + data += msg.pack() + + return data + + +class SMB2CreateDurableHandleRequest(Structure): + """ + [MS-SMB2] 2.2.13.2.3 SMB2_CREATE_DURABLE_HANDLE_REQUEST + + Used by the client to mark the open as a durable open. + """ + + def __init__(self): + self.fields = OrderedDict([ + ('durable_request', BytesField(size=16, default=b"\x00" * 16)) + ]) + super(SMB2CreateDurableHandleRequest, self).__init__() + + +class SMB2CreateDurableHandleResponse(Structure): + """ + [MS-SMB2] 2.2.14.2.3 SMB2_CREATE_DURABLE_HANDLE_RESPONSE + + Sent by the server in response to an SMB2CreateDurableHandleRequest packet. + """ + + def __init__(self): + self.fields = OrderedDict([ + ('reserved', IntField(size=8)) + ]) + super(SMB2CreateDurableHandleResponse, self).__init__() + + +class SMB2CreateDurableHandleReconnect(Structure): + """ + [MS-SMB2] 2.2.13.2.4 SMB2_CREATE_DURABLE_HANDLE_RECONNECT + + Used by the client when attempting to reestablish a durable open + """ + + def __init__(self): + self.fields = OrderedDict([ + ('data', BytesField(size=16)) + ]) + super(SMB2CreateDurableHandleReconnect, self).__init__() + + +class SMB2CreateQueryMaximalAccessRequest(Structure): + """ + [MS-SMB2] 2.2.13.2.5 SMB2_CREATE_QUERY_MAXIMAL_ACCESS_REQUEST + + Used by the client to retrieve maximal access information as part of + processing the open. + """ + + def __init__(self): + self.fields = OrderedDict([ + ('timestamp', DateTimeField()) + ]) + super(SMB2CreateQueryMaximalAccessRequest, self).__init__() + + +class SMB2CreateQueryMaximalAccessResponse(Structure): + """ + [MS-SMB2] 2.2.14.2.5 SMB2_CREATE_QUERY_MAXIMAL_ACCESS_RESPONSE + + Used by the server in response to an SMB2CreateQueryMaximalAccessRequest + packet. + """ + + def __init__(self): + self.fields = OrderedDict([ + ('query_status', EnumField( + size=4, + enum_type=smbprotocol.connection.NtStatus, + enum_strict=False + )), + # either FilePipePrinterAccessMask or DirectoryAccessMask + ('maximal_access', IntField(size=4)) + ]) + super(SMB2CreateQueryMaximalAccessResponse, self).__init__() + + +class SMB2CreateAllocationSize(Structure): + """ + [MS-SMB2] 2.2.13.2.6 SMB2_CREATE_ALLOCATION_SIZE + + Used by the client to set the allocation size of a file that is being + newly created or overwritten. + """ + + def __init__(self): + self.fields = OrderedDict([ + ('allocation_size', IntField(size=8)) + ]) + super(SMB2CreateAllocationSize, self).__init__() + + +class SMB2CreateTimewarpToken(Structure): + """ + [MS-SMB2] 2.2.13.2.7 SMB2_CREATE_TIMEWARP_TOKEN + + Used by the client when requesting the server to open a version of the file + at a previous point in time. + """ + + def __init__(self): + self.fields = OrderedDict([ + ('timestamp', DateTimeField()) + ]) + super(SMB2CreateTimewarpToken, self).__init__() + + +class SMB2CreateRequestLease(Structure): + """ + [MS-SMB2] 2.2.13.2.8 SMB2_CREATE_REQUEST_LEASE + + Used by the cliet when requesting the server to return a lease. + """ + + def __init__(self): + self.fields = OrderedDict([ + ('lease_key', BytesField(size=16)), + ('lease_state', FlagField( + size=4, + flag_type=LeaseState + )), + ('lease_flags', IntField(size=4)), + ('lease_duration', IntField(size=8)) + ]) + super(SMB2CreateRequestLease, self).__init__() + + +class SMB2CreateResponseLease(Structure): + """ + [MS-SMB2] 2.2.14.2.10 SMB2_CREATE_RESPONSE_LEASE + + Sent by the server in response to an SMB2CreateRequestLease + """ + + def __init__(self): + self.fields = OrderedDict([ + ('lease_key', BytesField(size=16)), + ('lease_state', FlagField( + size=4, + flag_type=LeaseState + )), + ('lease_flags', FlagField( + size=4, + flag_type=LeaseResponseFlags + )), + ('lease_duration', IntField(size=8)) + ]) + super(SMB2CreateResponseLease, self).__init__() + + +class SMB2CreateQueryOnDiskIDResponse(Structure): + """ + [MS-SMB2] 2.2.14.2.9 SMB2_CREATE_QUERY_ON_DISK_ID + + Sent by the server in response to an SMB2CreateQueryOnDiskIDRequest packet. + """ + + def __init__(self): + self.fields = OrderedDict([ + ('disk_file_id', IntField(size=8)), + ('volume_id', IntField(size=8)), + ('reserved', BytesField( + size=16, + default=b"\x00" * 16 + )) + ]) + super(SMB2CreateQueryOnDiskIDResponse, self).__init__() + + +class SMB2CreateRequestLeaseV2(Structure): + """ + [MS-SMB2] 2.2.13.2.10 SMB2_CREATE_REQUEST_LEASE_V2 + + Used when the client is requesting the server to return a lease on a file + or directory. + Valid for the SMB 3.x family only + """ + + def __init__(self): + self.fields = OrderedDict([ + ('lease_key', BytesField(size=16)), + ('lease_state', FlagField( + size=4, + flag_type=LeaseState + )), + ('lease_flags', FlagField( + size=4, + flag_type=LeaseRequestFlags + )), + ('lease_duration', IntField(size=8)), + ('parent_lease_key', BytesField(size=16)), + ('epoch', BytesField(size=16)), + ('reserved', IntField(size=2)) + ]) + super(SMB2CreateRequestLeaseV2, self).__init__() + + +class SMB2CreateResponseLeaseV2(Structure): + """ + [MS-SMB2] 2.2.14.2.11 SMB2_CREATE_RESPONSE_LEASE_V2 + + Sent by the server in response to an SMB2CreateRequestLeaseV2 packet. + """ + + def __init__(self): + self.fields = OrderedDict([ + ('lease_key', BytesField(size=16)), + ('lease_state', FlagField( + size=4, + flag_type=LeaseState + )), + ('flags', FlagField( + size=4, + flag_type=LeaseResponseFlags + )), + ('lease_duration', IntField(size=8)), + ('parent_lease_key', BytesField(size=16)), + ('epoch', IntField(size=2)), + ('reserved', IntField(size=2)) + ]) + super(SMB2CreateResponseLeaseV2, self).__init__() + + +class SMB2CreateDurableHandleRequestV2(Structure): + """ + [MS-SMB2] 2.2.13.2.11 SMB2_CREATE_DURABLE_HANDLE_REQUEST_V2 + + Used by the client to request the server mark the open as durable or + persistent. + Valid for the SMB 3.x family only + """ + + def __init__(self): + self.fields = OrderedDict([ + # timeout in milliseconds + ('timeout', IntField(size=4)), + ('flags', FlagField( + size=4, + flag_type=DurableHandleFlags + )), + ('reserved', IntField(size=8)), + ('create_guid', UuidField(size=16)) + ]) + super(SMB2CreateDurableHandleRequestV2, self).__init__() + + +class SMB2CreateDurableHandleReconnectV2(Structure): + """ + [MS-SMB2] 2.2.13.2.12 SMB2_CREATE_DURABLE_HANDLE_RECONNECT_V2 + + Used by the client when reestablishing a durable open. + Valid for the SMB 3.x family only + """ + + def __init__(self): + self.fields = OrderedDict([ + ('file_id', BytesField(size=16)), + ('create_guid', UuidField(size=16)), + ('flags', FlagField( + size=4, + flag_type=DurableHandleFlags + )) + ]) + super(SMB2CreateDurableHandleReconnectV2, self).__init__() + + +class SMB2CreateDurableHandleResponseV2(Structure): + """ + [MS-SMB2] 2.2.14.2.12 SMB2_CREATE_DURABLE_HANDLE_RESPONSE_V2 + + Sent by the server in response to an SMB2CreateDurableHandleRequestV2 + packet. + """ + + def __init__(self): + self.fields = OrderedDict([ + ('timeout', IntField(size=4)), + ('flags', FlagField( + size=4, + flag_type=DurableHandleFlags + )) + ]) + super(SMB2CreateDurableHandleResponseV2, self).__init__() + + +class SMB2CreateAppInstanceId(Structure): + """ + [MS-SMB2] 2.2.13.2.13 SMB2_CREATE_APP_INSTANCE_ID + + Used by the client when supplying an identifier provided by an application. + Valid for the SMB 3.x family and should also have an durable handle on the + create request. + """ + + def __init__(self): + self.fields = OrderedDict([ + ('structure_size', IntField( + size=2, + default=20 + )), + ('reserved', IntField(size=2)), + ('app_instance_id', BytesField(size=16)) + ]) + super(SMB2CreateAppInstanceId, self).__init__() + + +class SMB2SVHDXOpenDeviceContextRequest(Structure): + """ + [MS-SMB2] 2.2.13.2.14 SVHDX_OPEN_DEVICE_CONTEXT + [MS-RSVD] 2.2.4.12 SVHDX_OPEN_DEVICE_CONTEXT + + Used to open the shared virtual disk file. + """ + + def __init__(self): + self.fields = OrderedDict([ + ('version', IntField( + size=4, + default=1 + )), + ('has_initiator_id', BoolField( + size=1, + default=lambda s: len(s['initiator_host_name']) > 0 + )), + ('reserved', BytesField( + size=3, + default=b"\x00\x00\x00" + )), + ('initiator_id', UuidField(size=16)), + ('originator_flags', EnumField( + size=4, + enum_type=SVHDXOriginatorFlags + )), + ('open_request_id', IntField(size=8)), + ('initiator_host_name_length', IntField( + size=2, + default=lambda s: len(s['initiator_host_name']) + )), + # utf-16-le encoded string + ('initiator_host_name', BytesField( + size=lambda s: s['initiator_host_name_length'].get_value() + )) + ]) + super(SMB2SVHDXOpenDeviceContextRequest, self).__init__() + + +class SMB2SVHDXOpenDeviceContextResponse(Structure): + """ + [MS-SMB2] 2.2.14.2.14 SVHDX_OPEN_DEVICE_CONTEXT_RESPONSE + [MS-RSVD] 2.2.4.31 SVHDX_OPEN_DEVICE_CONTEXT_RESPONSE + + The response packet sent by the server in response to an + SMB2VHDXOpenDeviceContextRequest + """ + + def __init__(self): + self.fields = OrderedDict([ + ('version', IntField( + size=4, + default=1 + )), + ('has_initiator_id', BoolField( + size=1, + default=lambda s: len(s['initiator_host_name']) > 0 + )), + ('reserved', BytesField( + size=3, + default=b"\x00\x00\x00" + )), + ('initiator_id', UuidField(size=16)), + ('flags', IntField(size=4)), + ('originator_flags', EnumField( + size=4, + enum_type=SVHDXOriginatorFlags + )), + ('open_request_id', IntField(size=8)), + ('initiator_host_name_length', IntField( + size=2, + default=lambda s: len(s['initiator_host_name']) + )), + # utf-16-le encoded string + ('initiator_host_name', BytesField( + size=lambda s: s['initiator_host_name_length'].get_value() + )) + ]) + super(SMB2SVHDXOpenDeviceContextResponse, self).__init__() + + +class SMB2SVHDXOpenDeviceContextV2Request(Structure): + """ + [MS-SMB2] 2.2.13.2.14 SVHDX_OPEN_DEVICE_CONTEXT + [MS-RSVD] 2.2.4.32 SVHDX_OPEN_DEVICE_CONTEXT_V2 + + Used to open the shared virtual disk file on the RSVD Protocol version 2 + """ + + def __init__(self): + self.fields = OrderedDict([ + ('version', IntField( + size=4, + default=2 + )), + ('has_initiator_id', BoolField( + size=1, + default=lambda s: len(s['initiator_host_name']) > 0 + )), + ('reserved', BytesField( + size=3, + default=b"\x00\x00\x00" + )), + ('initiator_id', UuidField(size=16)), + ('originator_flags', EnumField( + size=4, + enum_type=SVHDXOriginatorFlags + )), + ('open_request_id', IntField(size=8)), + ('initiator_host_name_length', IntField( + size=2, + default=lambda s: len(s['initiator_host_name']) + )), + # utf-16-le encoded string + ('initiator_host_name', BytesField( + size=lambda s: s['initiator_host_name_length'].get_value() + )), + ('virtual_disk_properties_initialized', IntField(size=4)), + ('server_service_version', IntField(size=4)), + ('virtual_sector_size', IntField(size=4)), + ('physical_sector_size', IntField(size=4)), + ('virtual_size', IntField(size=8)) + ]) + super(SMB2SVHDXOpenDeviceContextV2Request, self).__init__() + + +class SMB2SVHDXOpenDeviceContextV2Response(Structure): + """ + [MS-SMB2] 2.2.14.2.14 SVHDX_OPEN_DEVICE_CONTEXT_RESPONSE + [MS-RSVD] 2.2.4.32 SVHDX_OPEN_DEVICE_CONTEXT_V2_RESPONSE + + The response packet sent by the server in response to an + SMB2VHDXOpenDeviceContextV2Request + """ + + def __init__(self): + self.fields = OrderedDict([ + ('version', IntField( + size=4, + default=2 + )), + ('has_initiator_id', BoolField( + size=1, + default=lambda s: len(s['initiator_host_name']) > 0 + )), + ('reserved', BytesField( + size=3, + default=b"\x00\x00\x00" + )), + ('initiator_id', UuidField(size=16)), + ('flags', IntField(size=4)), + ('originator_flags', EnumField( + size=4, + enum_type=SVHDXOriginatorFlags + )), + ('open_request_id', IntField(size=8)), + ('initiator_host_name_length', IntField( + size=2, + default=lambda s: len(s['initiator_host_name']) + )), + # utf-16-le encoded string + ('initiator_host_name', BytesField( + size=lambda s: s['initiator_host_name_length'].get_value() + )), + ('virtual_disk_properties_initialized', IntField(size=4)), + ('server_service_version', IntField(size=4)), + ('virtual_sector_size', IntField(size=4)), + ('physical_sector_size', IntField(size=4)), + ('virtual_size', IntField(size=8)) + ]) + super(SMB2SVHDXOpenDeviceContextV2Response, self).__init__() + + +class SMB2CreateAppInstanceVersion(Structure): + """ + [MS-SMB2] 2.2.13.2.15 SMB2_CREATE_APP_INSTANCE_VERSION + + Used when the client is supplying a version for the app instance identifier + provided by an application. + Valid for the SMB 3.1.1+ family + """ + + def __init__(self): + self.fields = OrderedDict([ + ('structure_size', IntField( + size=2, + default=24 + )), + ('reserved', IntField(size=2)), + ('padding', IntField(size=4)), + ('app_instance_version_high', IntField(size=8)), + ('app_instance_version_low', IntField(size=8)) + ]) + super(SMB2CreateAppInstanceVersion, self).__init__() diff --git a/smbprotocol/exceptions.py b/smbprotocol/exceptions.py new file mode 100644 index 00000000..06c957b7 --- /dev/null +++ b/smbprotocol/exceptions.py @@ -0,0 +1,472 @@ +import binascii +import socket + +import smbprotocol.connection +from smbprotocol.structure import BytesField, EnumField, IntField, ListField, \ + Structure, StructureField + +try: + from collections import OrderedDict +except ImportError: # pragma: no cover + from ordereddict import OrderedDict + + +class SMBException(Exception): + # Generic SMB Exception with a message + pass + + +class SMBAuthenticationError(SMBException): + # Used for authentication specific errors + pass + + +class SMBUnsupportedFeature(SMBException): + + @property + def negotiated_dialect(self): + return self.args[0] + + @property + def required_dialect(self): + return self.args[1] + + @property + def feature_name(self): + return self.args[2] + + @property + def requires_newer(self): + if len(self.args) > 3: + return self.args[3] + else: + return None + + @property + def message(self): + if self.requires_newer is None: + msg_suffix = "" + elif self.requires_newer: + msg_suffix = " or newer" + else: + msg_suffix = " or older" + + required_dialect = self._get_dialect_name(self.required_dialect) + negotiated_dialect = self._get_dialect_name(self.negotiated_dialect) + + msg = "%s is not available on the negotiated dialect %s, " \ + "requires dialect %s%s"\ + % (self.feature_name, negotiated_dialect, required_dialect, + msg_suffix) + return msg + + def __str__(self): + return self.message + + def _get_dialect_name(self, dialect): + dialect_field = EnumField( + enum_type=smbprotocol.connection.Dialects, + enum_strict=False, + size=2) + dialect_field.set_value(dialect) + return str(dialect_field) + + +class SMBResponseException(SMBException): + + @property + def header(self): + # the full message that was returned by the server + return self.args[0] + + @property + def status(self): + # the raw int status value, used by method that catch this exception + # for control flow + return self.args[1] + + @property + def error_details(self): + # list of error_details returned by the server, currently used in + # the SMB 3.1.1 error response for certain situations + error = SMB2ErrorResponse() + error.unpack(self.header['data'].get_value()) + + error_details = [] + for raw_error_data in error['error_data'].get_value(): + nt_status = smbprotocol.connection.NtStatus + error_id = raw_error_data['error_id'].get_value() + raw_data = raw_error_data['error_context_data'].get_value() + if self.status == nt_status.STATUS_STOPPED_ON_SYMLINK: + error_data = SMB2SymbolicLinkErrorResponse() + error_data.unpack(raw_data) + elif self.status == nt_status.STATUS_BAD_NETWORK_NAME and \ + error_id == ErrorContextId.SMB2_ERROR_ID_SHARE_REDIRECT: + error_data = SMB2ShareRedirectErrorContext() + error_data.unpack(raw_data) + else: + # unknown context data so we just set it the raw bytes + error_data = raw_data + error_details.append(error_data) + + return error_details + + @property + def message(self): + error_details_msg = "" + for error_detail in self.error_details: + if isinstance(error_detail, SMB2SymbolicLinkErrorResponse): + detail_msg = self._get_symlink_error_detail_msg(error_detail) + elif isinstance(error_detail, SMB2ShareRedirectErrorContext): + detail_msg = self._get_share_redirect_detail_msg(error_detail) + else: + # unknown error details in response, output raw bytes + detail_msg = "Raw: %s"\ + % binascii.hexlify(error_detail).decode('utf-8') + + # the first details message is set differently + if error_details_msg == "": + error_details_msg = "%s - %s" % (error_details_msg, detail_msg) + else: + error_details_msg = "%s, %s" % (error_details_msg, detail_msg) + + status_hex = format(self.status, 'x') + error_message = "%s: 0x%s%s" % (str(self.header['status']), + status_hex, error_details_msg) + return "Received unexpected status from the server: %s" % error_message + + def __str__(self): + return self.message + + def _get_share_redirect_detail_msg(self, error_detail): + ip_addresses = [] + for ip_addr in error_detail['ip_addr_move_list'].get_value(): + ip_addresses.append(ip_addr.get_ipaddress()) + + resource_name = error_detail['resource_name'].get_value(). \ + decode('utf-16-le') + detail_msg = "IP Addresses: '%s', Resource Name: %s" \ + % ("', '".join(ip_addresses), resource_name) + return detail_msg + + def _get_symlink_error_detail_msg(self, error_detail): + flag = str(error_detail['flags']) + print_name = error_detail.get_print_name() + sub_name = error_detail.get_substitute_name() + detail_msg = "Flag: %s, Print Name: %s, Substitute Name: %s" \ + % (flag, print_name, sub_name) + return detail_msg + + +class ErrorContextId(object): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.2.1 SMB2 Error Context Response ErrorId + An identifier for the error context, it MUST be set to one of the following + values. + """ + SMB2_ERROR_ID_DEFAULT = 0x00000000 + SMB2_ERROR_ID_SHARE_REDIRECT = 0x53526472 + + +class SymbolicLinkErrorFlags(object): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.2.2.1 Symbolic Link Error Response Flags + Specifies whether the substitute name is an absolute target path or a path + name relative to the directory containing the symbolic link + """ + SYMLINK_FLAG_ABSOLUTE = 0x00000000 + SYMLINK_FLAG_RELATIVE = 0x00000001 + + +class IpAddrType(object): + """ + [MS-SM2] v53.0 2017-09-15 + + 2.2.2.2.2.1 MOVE_DST_IPADDR structure Type + Indicates the type of the destionation IP address. + """ + MOVE_DST_IPADDR_V4 = 0x00000001 + MOVE_DST_IPADDR_V6 = 0x00000002 + + +class SMB2ErrorResponse(Structure): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.2 SMB2 Error Response + The SMB2 Error Response packet is sent by the server to respond to a + request that has failed or encountered an error. This is only used in the + SMB 3.1.1 dialect and this code won't decode values based on older versions + """ + + def __init__(self): + self.fields = OrderedDict([ + ('structure_size', IntField( + size=2, + default=9, + )), + ('error_context_count', IntField( + size=1, + default=lambda s: len(s['error_data'].get_value()), + )), + ('reserved', IntField(size=1)), + ('byte_count', IntField( + size=4, + default=lambda s: len(s['error_data']), + )), + ('error_data', ListField( + size=lambda s: s['byte_count'].get_value(), + list_count=lambda s: s['error_context_count'].get_value(), + list_type=StructureField( + structure_type=SMB2ErrorContextResponse + ), + unpack_func=lambda s, d: self._error_data_value(s, d) + )), + ]) + super(SMB2ErrorResponse, self).__init__() + + def _error_data_value(self, structure, data): + context_responses = [] + while len(data) > 0: + response = SMB2ErrorContextResponse() + data = response.unpack(data) + context_responses.append(response) + + return context_responses + + +class SMB2ErrorContextResponse(Structure): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.2.1 SMB2 ERROR Context Response + For the SMB dialect 3.1.1, the server formats the error data as an array of + SMB2 Error Context structures in the SMB2ErrorResponse message. + + """ + + def __init__(self): + self.fields = OrderedDict([ + ('error_data_length', IntField( + size=4, + default=lambda s: len(s['error_context_data']), + )), + ('error_id', EnumField( + size=4, + default=ErrorContextId.SMB2_ERROR_ID_DEFAULT, + enum_type=ErrorContextId + )), + ('error_context_data', BytesField( + size=lambda s: s['error_data_length'].get_value(), + )), + ]) + super(SMB2ErrorContextResponse, self).__init__() + + +class SMB2SymbolicLinkErrorResponse(Structure): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.2.2.1 Symbolic Link Error Response + The Symbolic Link Error Response is used to indicate that a symbolic link + was encountered on the create. It describes the target path that the client + MUST use if it requires to follow the symbolic link. + """ + + def __init__(self): + self.fields = OrderedDict([ + ('symlink_length', IntField( + size=4, + default=lambda s: len(s) - 4 + )), + ('symlink_error_tag', BytesField( + size=4, + default=b"\x53\x59\x4d\x4c" + )), + ('reparse_tag', BytesField( + size=4, + default=b"\x0c\x00\x00\xa0" + )), + ('reparse_data_length', IntField( + size=2, + default=lambda s: len(s['path_buffer']) + 12 + )), + # the len in utf-16-le bytes of the path beyond the substitute name + # of the original target, e.g. \\server\share\symlink\file.txt + # would be length of \file.txt in utf-16-le form, this is used by + # the client to find out what part of the original path to append + # to the substitute name returned by the server. + ('unparsed_path_length', IntField(size=2)), + ('substitute_name_offset', IntField(size=2)), + ('substitute_name_length', IntField(size=2)), + ('print_name_offset', IntField(size=2)), + ('print_name_length', IntField(size=2)), + ('flags', EnumField( + size=4, + enum_type=SymbolicLinkErrorFlags + )), + # use the get/set_name functions to get/set these values as they + # also (d)encode the text and set the length and offset accordingly + ('path_buffer', BytesField( + size=lambda s: self._get_name_length(s, True) + )) + ]) + super(SMB2SymbolicLinkErrorResponse, self).__init__() + + def _get_name_length(self, structure, first): + print_name_len = structure['print_name_length'].get_value() + sub_name_len = structure['substitute_name_length'].get_value() + return print_name_len + sub_name_len + + def set_name(self, print_name, substitute_name): + """ + Set's the path_buffer and print/substitute name length of the message + with the values passed in. These values should be a string and not a + byte string as it is encoded in this function. + + :param print_name: The print name string to set + :param substitute_name: The substitute name string to set + """ + print_bytes = print_name.encode('utf-16-le') + sub_bytes = substitute_name.encode('utf-16-le') + path_buffer = print_bytes + sub_bytes + + self['print_name_offset'].set_value(0) + self['print_name_length'].set_value(len(print_bytes)) + self['substitute_name_offset'].set_value(len(print_bytes)) + self['substitute_name_length'].set_value(len(sub_bytes)) + self['path_buffer'].set_value(path_buffer) + + def get_print_name(self): + offset = self['print_name_offset'].get_value() + length = self['print_name_length'].get_value() + name_bytes = self['path_buffer'].get_value()[offset:offset + length] + return name_bytes.decode('utf-16-le') + + def get_substitute_name(self): + offset = self['substitute_name_offset'].get_value() + length = self['substitute_name_length'].get_value() + name_bytes = self['path_buffer'].get_value()[offset:offset + length] + return name_bytes.decode('utf-16-le') + + +class SMB2ShareRedirectErrorContext(Structure): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.2.2.2 Share Redirect Error Context Response + Response to a Tree Connect with the + SMB2_TREE_CONNECT_FLAG_REDIRECT_TO_OWNER flag set. + """ + + def __init__(self): + self.fields = OrderedDict([ + ('structure_size', IntField( + size=4, + default=lambda s: len(s) + )), + ('notification_type', IntField( + size=4, + default=3 + )), + ('resource_name_offset', IntField( + size=4, + default=lambda s: self._resource_name_offset(s) + )), + ('resource_name_length', IntField( + size=4, + default=lambda s: len(s['resource_name']) + )), + ('flags', IntField( + size=2, + default=0 + )), + ('target_type', IntField( + size=2, + default=0 + )), + ('ip_addr_count', IntField( + size=4, + default=lambda s: len(s['ip_addr_move_list'].get_value()) + )), + ('ip_addr_move_list', ListField( + size=lambda s: s['ip_addr_count'].get_value() * 24, + list_count=lambda s: s['ip_addr_count'].get_value(), + list_type=StructureField( + size=24, + structure_type=SMB2MoveDstIpAddrStructure + ) + )), + ('resource_name', BytesField( + size=lambda s: s['resource_name_length'].get_value() + )) + ]) + super(SMB2ShareRedirectErrorContext, self).__init__() + + def _resource_name_offset(self, structure): + min_structure_size = 24 + addr_list_size = len(structure['ip_addr_move_list']) + return min_structure_size + addr_list_size + + +class SMB2MoveDstIpAddrStructure(Structure): + """ + [MS-SMB2] c53.0 2017-09-15 + + 2.2.2.2.2.1 MOVE_DST_IPADDR structure + Used to indicate the destination IP address. + """ + + def __init__(self): + self.fields = OrderedDict([ + ('type', EnumField( + size=4, + enum_type=IpAddrType + )), + ('reserved', IntField(size=4)), + ('ip_address', BytesField( + size=lambda s: self._ip_address_size(s) + )), + ('reserved2', BytesField( + size=lambda s: self._reserved2_size(s), + default=lambda s: b"\x00" * self._reserved2_size(s) + )) + ]) + super(SMB2MoveDstIpAddrStructure, self).__init__() + + def _ip_address_size(self, structure): + if structure['type'].get_value() == IpAddrType.MOVE_DST_IPADDR_V4: + return 4 + else: + return 16 + + def _reserved2_size(self, structure): + if structure['type'].get_value() == IpAddrType.MOVE_DST_IPADDR_V4: + return 12 + else: + return 0 + + def get_ipaddress(self): + # get's the IP address in a human readable format + ip_address = self['ip_address'].get_value() + if self['type'].get_value() == IpAddrType.MOVE_DST_IPADDR_V4: + return socket.inet_ntoa(ip_address) + else: + addr = binascii.hexlify(ip_address).decode('utf-8') + return ":".join([addr[i:i + 4] for i in range(0, len(addr), 4)]) + + def set_ipaddress(self, address): + # set's the IP address from a human readable format, for IPv6, this + # needs to be the full IPv6 address + if self['type'].get_value() == IpAddrType.MOVE_DST_IPADDR_V4: + self['ip_address'].set_value(socket.inet_aton(address)) + else: + addr = address.replace(":", "") + if len(addr) != 32: + raise ValueError("When setting an IPv6 address, it must be in " + "the full form without concatenation") + self['ip_address'].set_value(binascii.unhexlify(addr)) diff --git a/smbprotocol/ioctl.py b/smbprotocol/ioctl.py new file mode 100644 index 00000000..67d81b19 --- /dev/null +++ b/smbprotocol/ioctl.py @@ -0,0 +1,570 @@ +import binascii +import socket + +from smbprotocol.structure import BytesField, EnumField, FlagField, IntField, \ + ListField, Structure, StructureField, UuidField +from smbprotocol.connection import Capabilities, Commands, Dialects, \ + SecurityMode + +try: + from collections import OrderedDict +except ImportError: # pragma: no cover + from ordereddict import OrderedDict + + +class CtlCode(object): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.31 SMB2 IOCTL Request CtlCode + The control code of the FSCTL_IOCTL method. + """ + FSCTL_DFS_GET_REFERRALS = 0x00060194 + FSCTL_PIPE_PEEK = 0x0011400C + FSCTL_PIPE_WAIT = 0x00110018 + FSCTL_PIPE_TRANSCEIVE = 0x0011C017 + FSCTL_SRV_COPYCHUNK = 0x001440F2 + FSCTL_SRV_ENUMERATE_SNAPSHOTS = 0x00144064 + FSCTL_SRV_REQUEST_RESUME_KEY = 0x00140078 + FSCTL_SRV_READ_HASH = 0x001441bb + FSCTL_SRV_COPYCHUNK_WRITE = 0x001480F2 + FSCTL_LMR_REQUEST_RESILIENCY = 0x001401D4 + FSCTL_QUERY_NETWORK_INTERFACE_INFO = 0x001401FC + FSCTL_SET_REPARSE_POINT = 0x000900A4 + FSCTL_DFS_GET_REFERRALS_EX = 0x000601B0 + FSCTL_FILE_LEVEL_TRIM = 0x00098208 + FSCTL_VALIDATE_NEGOTIATE_INFO = 0x00140204 + + +class IOCTLFlags(object): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.31 SMB2 IOCTL Request Flags + A flag that indicates how to process the operation + """ + SMB2_0_IOCTL_IS_IOCTL = 0x00000000 + SMB2_0_IOCTL_IS_FSCTL = 0x00000001 + + +class HashVersion(object): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.31.2 SRV_READ_HASH Request HashVersion + The version number of the algorithm used to create the Content Information. + """ + SRV_HASH_VER_1 = 0x00000001 + SRV_HASH_VER_2 = 0x00000002 + + +class HashRetrievalType(object): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.31.2 SRV_READ_HASH Request HashRetrievalType + Indicates the nature of the Offset field in am SMB2SrvReadHashRequest + packet. + """ + SRV_HASH_RETRIEVE_HASH_BASED = 0x00000001 + SRV_HASH_RETRIEVE_FILE_BASED = 0x00000002 + + +class IfCapability(object): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.32.5 NETWORK_INTERFACE_INFO Response Capability + The capabilities of the network interface + """ + RSS_CAPABLE = 0x00000001 + RDMA_CAPABLE = 0x00000002 + + +class SockAddrFamily(object): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.32.5.1 SOCKADDR_STORAGE Family + The address family of the socket. + """ + INTER_NETWORK = 0x0002 + INTER_NETWORK_V6 = 0x0017 + + +class SMB2IOCTLRequest(Structure): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.31 SMB2 IOCTL Request + Send by the client to issue an implementation-specific file system control + or device control command across the network. + """ + COMMAND = Commands.SMB2_IOCTL + + def __init__(self): + self.fields = OrderedDict([ + ('structure_size', IntField(size=2, default=57)), + ('reserved', IntField(size=2, default=0)), + ('ctl_code', EnumField( + size=4, + enum_type=CtlCode, + )), + ('file_id', BytesField(size=16)), + ('input_offset', IntField( + size=4, + default=lambda s: self._buffer_offset_value(s) + )), + ('input_count', IntField( + size=4, + default=lambda s: len(s['buffer']), + )), + ('max_input_response', IntField(size=4)), + ('output_offset', IntField( + size=4, + default=lambda s: self._buffer_offset_value(s) + )), + ('output_count', IntField(size=4, default=0)), + ('max_output_response', IntField(size=4)), + ('flags', EnumField( + size=4, + enum_type=IOCTLFlags, + )), + ('reserved2', IntField(size=4, default=0)), + ('buffer', BytesField( + size=lambda s: s['input_count'].get_value() + )) + ]) + super(SMB2IOCTLRequest, self).__init__() + + def _buffer_offset_value(self, structure): + # The offset from the beginning of the SMB2 header to the value of the + # buffer, 0 if no buffer is set + if len(structure['buffer']) > 0: + header_size = 64 + request_size = structure['structure_size'].get_value() + return header_size + request_size - 1 + else: + return 0 + + +class SMB2SrvCopyChunkCopy(Structure): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.31.1 SRV_COPYCHUNK_COPY + Sent in an SMB2 IOCTL Request by the client to initiate a server-side copy + of data. + """ + + def __init__(self): + self.fields = OrderedDict([ + ('source_key', BytesField(size=24)), + ('chunk_count', IntField( + size=4, + default=lambda s: len(s['chunks'].get_value()) + )), + ('reserved', IntField(size=4)), + ('chunks', ListField( + size=lambda s: s['chunk_count'].get_value() * 24, + list_count=lambda s: s['chunk_count'].get_value(), + list_type=StructureField( + size=24, + structure_type=SMB2SrvCopyChunk + ) + )) + ]) + super(SMB2SrvCopyChunkCopy, self).__init__() + + +class SMB2SrvCopyChunk(Structure): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.31.1.1 SRC_COPYCHUNK + Packet sent in the Chunks array of an SRC_COPY_CHUNK_COPY packet to + describe an individual data range to copy. + """ + + def __init__(self): + self.fields = OrderedDict([ + ('source_offset', IntField(size=8)), + ('target_offset', IntField(size=8)), + ('length', IntField(size=4)), + ('reserved', IntField(size=4)) + ]) + super(SMB2SrvCopyChunk, self).__init__() + + +class SMB2SrvReadHashRequest(Structure): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.31.2 SRC_READ_HASH Request + Sent by the client in an SMB2 IOCTL Request to retrieve data from the + Content Information File associated with a specified file. + Not valid for the SMB 2.0.2 dialect. + """ + + def __init__(self): + self.fields = OrderedDict([ + ('hash_type', IntField( + size=4, + default=1 # SRV_HASH_TYPE_PEER_DIST + )), + ('hash_version', EnumField( + size=4, + enum_type=HashVersion + )), + ('hash_retrieval_type', EnumField( + size=4, + enum_type=HashRetrievalType + )), + ('length', IntField(size=4)), + ('offset', IntField(size=8)) + ]) + super(SMB2SrvReadHashRequest, self).__init__() + + +class SMB2SrvNetworkResiliencyRequest(Structure): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.31.3 NETWORK_RESILIENCY_REQUEST Request + Sent by the client to request resiliency for a specified open file. + Not valid for the SMB 2.0.2 dialect. + """ + + def __init__(self): + self.fields = OrderedDict([ + # timeout is in milliseconds + ('timeout', IntField(size=4)), + ('reserved', IntField(size=4)) + ]) + super(SMB2SrvNetworkResiliencyRequest, self).__init__() + + +class SMB2ValidateNegotiateInfoRequest(Structure): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.31.4 VALIDATE_NEGOTIATE_INFO Request + Packet sent to the server to request validation of a previous SMB 2 + NEGOTIATE request. + Only valid for the SMB 3.0 and 3.0.2 dialects. + """ + + def __init__(self): + self.fields = OrderedDict([ + ('capabilities', FlagField( + size=4, + flag_type=Capabilities, + )), + ('guid', UuidField()), + ('security_mode', EnumField( + size=2, + enum_type=SecurityMode, + )), + ('dialect_count', IntField( + size=2, + default=lambda s: len(s['dialects'].get_value()) + )), + ('dialects', ListField( + size=lambda s: s['dialect_count'].get_value() * 2, + list_count=lambda s: s['dialect_count'].get_value(), + list_type=EnumField(size=2, enum_type=Dialects), + )) + ]) + super(SMB2ValidateNegotiateInfoRequest, self).__init__() + + +class SMB2IOCTLResponse(Structure): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.32 SMB2 IOCTL Response + Sent by the server to transmit the results of a client SMB2 IOCTL Request. + """ + COMMAND = Commands.SMB2_IOCTL + + def __init__(self): + self.fields = OrderedDict([ + ('structure_size', IntField(size=2, default=49)), + ('reserved', IntField(size=2, default=0)), + ('ctl_code', EnumField( + size=4, + enum_type=CtlCode, + )), + ('file_id', BytesField(size=16)), + ('input_offset', IntField(size=4)), + ('input_count', IntField(size=4)), + ('output_offset', IntField(size=4)), + ('output_count', IntField(size=4)), + ('flags', IntField(size=4, default=0)), + ('reserved2', IntField(size=4, default=0)), + ('buffer', BytesField( + size=lambda s: s['output_count'].get_value(), + )) + ]) + super(SMB2IOCTLResponse, self).__init__() + + +class SMB2SrvCopyChunkResponse(Structure): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.32.1 SRV_COPYCHUNK_RESPONSE + """ + + def __init__(self): + self.fields = OrderedDict([ + ('chunks_written', IntField(size=4)), + ('chunk_bytes_written', IntField(size=4)), + ('total_bytes_written', IntField(size=4)) + ]) + super(SMB2SrvCopyChunkResponse, self).__init__() + + +class SMB2SrvSnapshotArray(Structure): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.32.2 SRV_SNAPSHOT_ARRAY + Sent by the server in response to an SMB2IOCTLResponse for the + FSCTL_SRV_ENUMERATE_SNAPSHOTS request. + """ + + def __init__(self): + # TODO: validate this further when working with actual snapshots + self.fields = OrderedDict([ + ('number_of_snapshots', IntField(size=4)), + ('number_of_snapshots_returned', IntField(size=4)), + ('snapshot_array_size', IntField(size=4)), + ('snapshots', BytesField()) + ]) + super(SMB2SrvSnapshotArray, self).__init__() + + +class SMB2SrvRequestResumeKey(Structure): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.32.3 SRV_REQUEST_RESUME_KEY Response + Sent by the server in response to an SMB2IOCTLResponse for the + FSCTL_SRV_REQUEST_RESUME_KEY request. + """ + + def __init__(self): + self.fields = OrderedDict([ + ('resume_key', BytesField(size=24)), + ('context_length', IntField( + size=4, + default=0 + )) + ]) + super(SMB2SrvRequestResumeKey, self).__init__() + + +class SMB2NetworkInterfaceInfo(Structure): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.32.5 NETWORK_INTERFACE_INFO Response + The NETWORK_INTERFACE_INFO returned to the client in an SMB2IOCTLResposne + for the FSCTL_QUERY_NETWORK_INTERFACE_INFO. + + Use the pack_multiple and unpack_multiple to handle multiple interfaces + that are returned in the SMB2IOCTLResponse + """ + + def __init__(self): + self.fields = OrderedDict([ + # 0 if no more network interfaces + ('next', IntField(size=4)), + ('if_index', IntField(size=4)), + ('capability', FlagField( + size=4, + flag_type=IfCapability + )), + ('reserved', IntField(size=4)), + ('link_speed', IntField(size=8)), + ('sock_addr_storage', StructureField( + size=128, + structure_type=SockAddrStorage + )) + ]) + super(SMB2NetworkInterfaceInfo, self).__init__() + + @staticmethod + def pack_multiple(messages): + """ + Packs a list of SMB2NetworkInterfaceInfo messages and set's the next + value accordingly. The byte value returned is then attached to the + SMBIOCTLResponse message. + + :param messages: List of SMB2NetworkInterfaceInfo messages + :return: bytes of the packed messages + """ + data = b"" + msg_count = len(messages) + for i, msg in enumerate(messages): + if i == msg_count - 1: + msg['next'] = 0 + else: + msg['next'] = 152 + data += msg.pack() + return data + + @staticmethod + def unpack_multiple(data): + """ + Get's a list of SMB2NetworkInterfaceInfo messages from the byte value + passed in. This is the raw buffer value that is set on the + SMB2IOCTLResponse message. + + :param data: bytes of the messages + :return: List of SMB2NetworkInterfaceInfo messages + """ + chunks = [] + while data: + info = SMB2NetworkInterfaceInfo() + data = info.unpack(data) + chunks.append(info) + + return chunks + + +class SockAddrStorage(Structure): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.32.5.1 SOCKADDR_STORAGE + Socket Address information. + """ + + def __init__(self): + self.fields = OrderedDict([ + ('family', EnumField( + size=2, + enum_type=SockAddrFamily + )), + ('buffer', StructureField( + size=lambda s: self._get_buffer_size(s), + structure_type=lambda s: self._get_buffer_structure_type(s) + )), + ('reserved', BytesField( + size=lambda s: self._get_reserved_size(s), + default=lambda s: b"\x00" * self._get_reserved_size(s) + )) + ]) + super(SockAddrStorage, self).__init__() + + def _get_buffer_size(self, structure): + if structure['family'].get_value() == SockAddrFamily.INTER_NETWORK: + return 14 + else: + return 26 + + def _get_buffer_structure_type(self, structure): + if structure['family'].get_value() == SockAddrFamily.INTER_NETWORK: + return SockAddrIn + else: + return SockAddrIn6 + + def _get_reserved_size(self, structure): + if structure['family'].get_value() == SockAddrFamily.INTER_NETWORK: + return 112 + else: + return 100 + + +class SockAddrIn(Structure): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.32.5.1.1 SOCKADDR_IN + Socket address information for an IPv4 address + """ + + def __init__(self): + self.fields = OrderedDict([ + ('port', IntField(size=2)), + ('ipv4_address', BytesField(size=4)), + ('reserved', IntField(size=8)) + ]) + super(SockAddrIn, self).__init__() + + def get_ipaddress(self): + addr_bytes = self['ipv4_address'].get_value() + return socket.inet_ntoa(addr_bytes) + + def set_ipaddress(self, address): + # set's the ipv4 address field from the address string passed in, this + # needs to be the full ipv4 address including periods, e.g. + # 192.168.1.1 + addr_bytes = socket.inet_aton(address) + self['ipv4_address'].set_value(addr_bytes) + + +class SockAddrIn6(Structure): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.32.5.1.2 SOCKADDR_IN6 + Socket address information for an IPv6 address + """ + + def __init__(self): + self.fields = OrderedDict([ + ('port', IntField(size=2)), + ('flow_info', IntField(size=4)), + ('ipv6_address', BytesField(size=16)), + ('scope_id', IntField(size=4)) + ]) + super(SockAddrIn6, self).__init__() + + def get_ipaddress(self): + # get's the full IPv6 Address, note this is the full address and has + # not been concatenated + addr_bytes = self['ipv6_address'].get_value() + address = binascii.hexlify(addr_bytes).decode('utf-8') + return ":".join([address[i:i + 4] for i in range(0, len(address), 4)]) + + def set_ipaddress(self, address): + # set's the ipv6_address field from the address passed in, note this + # needs to be the full ipv6 address, + # e.g. fe80:0000:0000:0000:0000:0000:0000:0000 and not any short form + address = address.replace(":", "") + if len(address) != 32: + raise ValueError("When setting an IPv6 address, it must be in the " + "full form without concatenation") + self['ipv6_address'].set_value(binascii.unhexlify(address)) + + +class SMB2ValidateNegotiateInfoResponse(Structure): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.32.6 VALIDATE_NEGOTIATE_INFO Response + Packet sent by the server on a request validation of SMB 2 negotiate + request. + """ + + def __init__(self): + self.fields = OrderedDict([ + ('capabilities', FlagField( + size=4, + flag_type=Capabilities, + )), + ('guid', UuidField()), + ('security_mode', EnumField( + size=2, + enum_type=SecurityMode, + enum_strict=False + )), + ('dialect', EnumField( + size=2, + enum_type=Dialects + )) + ]) + super(SMB2ValidateNegotiateInfoResponse, self).__init__() diff --git a/smbprotocol/open.py b/smbprotocol/open.py new file mode 100644 index 00000000..70949251 --- /dev/null +++ b/smbprotocol/open.py @@ -0,0 +1,1351 @@ +import logging + +import smbprotocol.create_contexts +import smbprotocol.query_info +from smbprotocol.exceptions import SMBException, SMBResponseException, \ + SMBUnsupportedFeature +from smbprotocol.structure import BytesField, DateTimeField, EnumField, \ + FlagField, IntField, ListField, Structure, StructureField +from smbprotocol.connection import Commands, Dialects, NtStatus + +try: + from collections import OrderedDict +except ImportError: # pragma: no cover + from ordereddict import OrderedDict + +log = logging.getLogger(__name__) + + +class RequestedOplockLevel(object): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.31 SMB2 CREATE Request RequestedOplockLevel + The requested oplock level used when creating/accessing a file. + """ + SMB2_OPLOCK_LEVEL_NONE = 0x00 + SMB2_OPLOCK_LEVEL_II = 0x01 + SMB2_OPLOCK_LEVEL_EXCLUSIVE = 0x08 + SMB2_OPLOCK_LEVEL_BATCH = 0x09 + SMB2_OPLOCK_LEVEL_LEASE = 0xFF + + +class ImpersonationLevel(object): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.31 SMB2 CREATE Request ImpersonationLevel + The impersonation level requested by the application in a create request. + """ + Anonymous = 0x0 + Identification = 0x1 + Impersonation = 0x2 + Delegate = 0x3 + + +class ShareAccess(object): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.31 SMB2 CREATE Request ShareAccess + The sharing mode for the open + """ + FILE_SHARE_READ = 0x1 + FILE_SHARE_WRITE = 0x2 + FILE_SHARE_DELETE = 0x4 + + +class CreateDisposition(object): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.31 SMB2 CREATE Request CreateDisposition + Defines the action the server must take if the file that is specific + already exists. + """ + FILE_SUPERSEDE = 0x0 + FILE_OPEN = 0x1 + FILE_CREATE = 0x2 + FILE_OPEN_IF = 0x3 + FILE_OVERWRITE = 0x4 + FILE_OVERWRITE_IF = 0x5 + + +class CreateOptions(object): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.31 SMB2 CREATE Request CreateOptions + Specifies the options to be applied when creating or opening the file + """ + FILE_DIRECTORY_FILE = 0x00000001 + FILE_WRITE_THROUGH = 0x00000002 + FILE_SEQUENTIAL_ONLY = 0x00000004 + FILE_NO_INTERMEDIATE_BUFFERING = 0x00000008 + FILE_SYNCHRONOUS_IO_ALERT = 0x00000010 + FILE_SYNCHRONOUS_IO_NONALERT = 0x00000020 + FILE_NON_DIRECTORY_FILE = 0x00000040 + FILE_COMPLETE_IF_OPLOCKED = 0x00000100 + FILE_NO_EA_KNOWLEDGE = 0x00000200 + FILE_RANDOM_ACCESS = 0x00000800 + FILE_DELETE_ON_CLOSE = 0x00001000 + FILE_OPEN_BY_FILE_ID = 0x00002000 + FILE_OPEN_FOR_BACKUP_INTENT = 0x00004000 + FILE_NO_COMPRESSION = 0x00008000 + FILE_OPEN_REMOTE_INSTANCE = 0x00000400 + FILE_OPEN_REQUIRING_OPLOCK = 0x00010000 + FILE_DISALLOW_EXCLUSIVE = 0x00020000 + FILE_RESERVE_OPFILTER = 0x00100000 + FILE_OPEN_REPARSE_POINT = 0x00200000 + FILE_OPEN_NO_RECALL = 0x00400000 + FILE_OPEN_FOR_FREE_SPACE_QUERY = 0x00800000 + + +class FilePipePrinterAccessMask(object): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.13.1.1 File_Pipe_Printer_Access_Mask + Access Mask flag values to be used when accessing a file, pipe, or printer + """ + FILE_READ_DATA = 0x00000001 + FILE_WRITE_DATA = 0x00000002 + FILE_APPEND_DATA = 0x00000004 + FILE_READ_EA = 0x00000008 + FILE_WRITE_EA = 0x00000010 + FILE_DELETE_CHILD = 0x00000040 + FILE_EXECUTE = 0x00000020 + FILE_READ_ATTRIBUTES = 0x00000080 + FILE_WRITE_ATTRIBUTES = 0x00000100 + DELETE = 0x00010000 + READ_CONTROL = 0x00020000 + WRITE_DAC = 0x00040000 + WRITE_OWNER = 0x00080000 + SYNCHRONIZE = 0x00100000 + ACCESS_SYSTEM_SECURITY = 0x01000000 + MAXIMUM_ALLOWED = 0x02000000 + GENERIC_ALL = 0x10000000 + GENERIC_EXECUTE = 0x20000000 + GENERIC_WRITE = 0x40000000 + GENERIC_READ = 0x80000000 + + +class DirectoryAccessMask(object): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.13.1.2 Directory_Access_Mask + Access Mask flag values to be used when accessing a directory + """ + FILE_LIST_DIRECTORY = 0x00000001 + FILE_ADD_FILE = 0x00000002 + FILE_ADD_SUBDIRECTORY = 0x00000004 + FILE_READ_EA = 0x00000008 + FILE_WRITE_EA = 0x00000010 + FILE_TRAVERSE = 0x00000020 + FILE_DELETE_CHILD = 0x00000040 + FILE_READ_ATTRIBUTES = 0x00000080 + FILE_WRITE_ATTRIBUTES = 0x00000100 + DELETE = 0x00010000 + READ_CONTROL = 0x00020000 + WRITE_DAC = 0x00040000 + WRITE_OWNER = 0x00080000 + SYNCHRONIZE = 0x00100000 + ACCESS_SYSTEM_SECURITY = 0x01000000 + MAXIMUM_ALLOWED = 0x02000000 + GENERIC_ALL = 0x10000000 + GENERIC_EXECUTE = 0x20000000 + GENERIC_WRITE = 0x40000000 + GENERIC_READ = 0x80000000 + + +class FileFlags(object): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.14 SMB2 CREATE Response Flags + Flag that details info about the file that was opened. + """ + SMB2_CREATE_FLAG_REPARSEPOINT = 0x1 + + +class CreateAction(object): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.14 SMB2 CREATE Response Flags + The action taken in establishing the open. + """ + FILE_SUPERSEDED = 0x0 + FILE_OPENED = 0x1 + FILE_CREATED = 0x2 + FILE_OVERWRITTEN = 0x3 + + +class FileAttributes(object): + """ + [MS-FSCC] + + 2.6 File Attributes + Combination of file attributes for a file or directory + """ + FILE_ATTRIBUTE_ARCHIVE = 0x00000020 + FILE_ATTRIBUTE_COMPRESSED = 0x00000800 + FILE_ATTRIBUTE_DIRECTORY = 0x00000010 + FILE_ATTRIBUTE_ENCRYPTED = 0x00004000 + FILE_ATTRIBUTE_HIDDEN = 0x00000002 + FILE_ATTRIBUTE_NORMAL = 0x00000080 + FILE_ATTRIBUTE_NOT_CONTENT_INDEXED = 0x00002000 + FILE_ATTRIBUTE_OFFLINE = 0x00001000 + FILE_ATTRIBUTE_READONLY = 0x00000001 + FILE_ATTRIBUTE_REPARSE_POINT = 0x00000400 + FILE_ATTRIBUTE_SPARSE_FILE = 0x00000200 + FILE_ATTRIBUTE_SYSTEM = 0x00000004 + FILE_ATTRIBUTE_TEMPORARY = 0x00000100 + FILE_ATTRIBUTE_INTEGRITY_STREAM = 0x00008000 + FILE_ATTRIBUTE_NO_SCRUB_DATA = 0x00020000 + + +class CloseFlags(object): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.15 SMB2 CLOSE Request Flags + Flags to indicate how to process the operation + """ + SMB2_CLOSE_FLAG_POSTQUERY_ATTRIB = 0x01 + + +class ReadFlags(object): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.19 SMB2 READ Request Flags + Read flags for SMB 3.0.2 and newer dialects + """ + SMB2_READFLAG_READ_UNBUFFERED = 0x01 + + +class ReadWriteChannel(object): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.19/21 SMB2 READ/Write Request Channel + Channel information for an SMB2 READ Request message + """ + SMB2_CHANNEL_NONE = 0x0 + SMB2_CHANNEL_RDMA_V1 = 0x1 + SMB2_CHANNEL_RDMA_V1_INVALIDATE = 0x2 + + +class WriteFlags(object): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.21 SMB2 WRITE Request Flags + Flags to indicate how to process the operation + """ + SMB2_WRITEFLAG_WRITE_THROUGH = 0x00000001 + SMB2_WRITEFLAG_WRITE_UNBUFFERED = 0x00000002 + + +class FileInformationClass(object): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.33 SMB2 QUERY_DIRECTORY Request FileInformationClass + Describe the format the data MUST be returned in. The format structure must + is specified in https://msdn.microsoft.com/en-us/library/cc232064.aspx + """ + FILE_DIRECTORY_INFORMATION = 0x01 + FILE_FULL_DIRECTORY_INFORMATION = 0x02 + FILE_ID_FULL_DIRECTORY_INFORMATION = 0x26 + FILE_BOTH_DIRECTORY_INFORMATION = 0x03 + FILE_ID_BOTH_DIRECTORY_INFORMATION = 0x25 + FILE_NAMES_INFORMATION = 0x0C + + +class QueryDirectoryFlags(object): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.33 SMB2 QUERY_DIRECTORY Request Flags + Indicates how the query directory operation MUST be processed. + """ + SMB2_RESTART_SCANS = 0x01 + SMB2_RETURN_SINGLE_ENTRY = 0x02 + SMB2_INDEX_SPECIFIED = 0x04 + SMB2_REOPEN = 0x10 + + +class SMB2CreateRequest(Structure): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.13 SMB2 CREATE Request + The SMB2 Create Request packet is sent by a client to request either + creation of or access to a file. + """ + COMMAND = Commands.SMB2_CREATE + + def __init__(self): + # pep 80 char issues force me to define this here + create_con_req = smbprotocol.create_contexts.SMB2CreateContextRequest + self.fields = OrderedDict([ + ('structure_size', IntField( + size=2, + default=57, + )), + ('security_flags', IntField(size=1)), + ('requested_oplock_level', EnumField( + size=1, + enum_type=RequestedOplockLevel + )), + ('impersonation_level', EnumField( + size=4, + enum_type=ImpersonationLevel + )), + ('smb_create_flags', IntField(size=8)), + ('reserved', IntField(size=8)), + ('desired_access', IntField(size=4)), + ('file_attributes', IntField(size=4)), + ('share_access', FlagField( + size=4, + flag_type=ShareAccess + )), + ('create_disposition', EnumField( + size=4, + enum_type=CreateDisposition + )), + ('create_options', FlagField( + size=4, + flag_type=CreateOptions + )), + ('name_offset', IntField( + size=2, + default=120 # (header size 64) + (structure size 56) + )), + ('name_length', IntField( + size=2, + default=lambda s: len(s['buffer_path']) + )), + ('create_contexts_offset', IntField( + size=4, + default=lambda s: self._create_contexts_offset(s) + )), + ('create_contexts_length', IntField( + size=4, + default=lambda s: len(s['buffer_contexts']) + )), + # Technically these are all under buffer but we split it to make + # things easier + ('buffer_path', BytesField( + size=lambda s: s['name_length'].get_value(), + )), + ('padding', BytesField( + size=lambda s: self._padding_size(s), + default=lambda s: b"\x00" * self._padding_size(s) + )), + ('buffer_contexts', ListField( + size=lambda s: s['create_contexts_length'].get_value(), + list_type=StructureField( + structure_type=create_con_req + ), + unpack_func=lambda s, d: self._buffer_context_list(s, d) + )) + ]) + super(SMB2CreateRequest, self).__init__() + + def _create_contexts_offset(self, structure): + if len(structure['buffer_contexts']) == 0: + return 0 + else: + return structure['name_offset'].get_value() + \ + len(structure['padding']) + len(structure['buffer_path']) + + def _padding_size(self, structure): + # no padding is needed if there are no contexts + if structure['create_contexts_length'].get_value() == 0: + return 0 + + mod = structure['name_length'].get_value() % 8 + return 0 if mod == 0 else 8 - mod + + def _buffer_context_list(self, structure, data): + context_list = [] + last_context = data == b"" + while not last_context: + create_context = \ + smbprotocol.create_contexts.SMB2CreateContextRequest() + data = create_context.unpack(data) + context_list.append(create_context) + last_context = create_context['next'].get_value() == 0 + + return context_list + + +class SMB2CreateResponse(Structure): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.14 SMB2 CREATE Response + The SMB2 Create Response packet is sent by the server to an SMB2 CREATE + Request. + """ + COMMAND = Commands.SMB2_CREATE + + def __init__(self): + create_con_req = smbprotocol.create_contexts.SMB2CreateContextRequest + self.fields = OrderedDict([ + ('structure_size', IntField( + size=2, + default=89 + )), + ('oplock_level', EnumField( + size=1, + enum_type=RequestedOplockLevel + )), + ('flag', FlagField( + size=1, + flag_type=FileFlags + )), + ('create_action', EnumField( + size=4, + enum_type=CreateAction + )), + ('creation_time', DateTimeField(size=8)), + ('last_access_time', DateTimeField(size=8)), + ('last_write_time', DateTimeField(size=8)), + ('change_time', DateTimeField(size=8)), + ('allocation_size', IntField(size=8)), + ('end_of_file', IntField(size=8)), + ('file_attributes', FlagField( + size=4, + flag_type=FileAttributes + )), + ('reserved2', IntField(size=4)), + ('file_id', BytesField( + size=16 + )), + ('create_contexts_offset', IntField( + size=4, + default=lambda s: self._create_contexts_offset(s) + )), + ('create_contexts_length', IntField( + size=4, + default=lambda s: len(s['buffer']) + )), + ('buffer', ListField( + size=lambda s: s['create_contexts_length'].get_value(), + list_type=StructureField( + structure_type=create_con_req + ), + unpack_func=lambda s, d: self._buffer_context_list(s, d) + )) + ]) + super(SMB2CreateResponse, self).__init__() + + def _create_contexts_offset(self, structure): + if len(structure['buffer']) == 0: + return 0 + else: + return 152 + + def _buffer_context_list(self, structure, data): + context_list = [] + last_context = data == b"" + while not last_context: + create_context = \ + smbprotocol.create_contexts.SMB2CreateContextRequest() + data = create_context.unpack(data) + context_list.append(create_context) + last_context = create_context['next'].get_value() == 0 + + return context_list + + +class SMB2CloseRequest(Structure): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.15 SMB2 CLOSE Request + Used by the client to close an instance of a file + """ + COMMAND = Commands.SMB2_CLOSE + + def __init__(self): + self.fields = OrderedDict([ + ('structure_size', IntField( + size=2, + default=24 + )), + ('flags', FlagField( + size=2, + flag_type=CloseFlags + )), + ('reserved', IntField(size=4)), + ('file_id', BytesField(size=16)) + ]) + super(SMB2CloseRequest, self).__init__() + + +class SMB2CloseResponse(Structure): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.16 SMB2 CLOSE Response + The response of a SMB2 CLOSE Request + """ + COMMAND = Commands.SMB2_CLOSE + + def __init__(self): + self.fields = OrderedDict([ + ('structure_size', IntField( + size=2, + default=60 + )), + ('flags', FlagField( + size=2, + flag_type=CloseFlags + )), + ('reserved', IntField(size=4)), + ('creation_time', DateTimeField()), + ('last_access_time', DateTimeField()), + ('last_write_time', DateTimeField()), + ('change_time', DateTimeField()), + ('allocation_size', IntField(size=8)), + ('end_of_file', IntField(size=8)), + ('file_attributes', FlagField( + size=4, + flag_type=FileAttributes + )) + ]) + super(SMB2CloseResponse, self).__init__() + + +class SMB2FlushRequest(Structure): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.17 SMB2 FLUSH Request + Flush all cached file information for a specified open of a file to the + persistent store that backs the file. + """ + COMMAND = Commands.SMB2_FLUSH + + def __init__(self): + self.fields = OrderedDict([ + ('structure_size', IntField( + size=2, + default=24 + )), + ('reserved1', IntField(size=2)), + ('reserved2', IntField(size=4)), + ('file_id', BytesField(size=16)) + ]) + super(SMB2FlushRequest, self).__init__() + + +class SMB2FlushResponse(Structure): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.18 SMB2 FLUSH Response + SMB2 FLUSH Response packet sent by the server. + """ + COMMAND = Commands.SMB2_FLUSH + + def __init__(self): + self.fields = OrderedDict([ + ('structure_size', IntField( + size=2, + default=4 + )), + ('reserved', IntField(size=2)) + ]) + super(SMB2FlushResponse, self).__init__() + + +class SMB2ReadRequest(Structure): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.19 SMB2 READ Request + The request is used to run a read operation on the file specified. + """ + COMMAND = Commands.SMB2_READ + + def __init__(self): + self.fields = OrderedDict([ + ('structure_size', IntField( + size=2, + default=49 + )), + ('padding', IntField(size=1)), + ('flags', FlagField( + size=1, + flag_type=ReadFlags + )), + ('length', IntField( + size=4 + )), + ('offset', IntField( + size=8 + )), + ('file_id', BytesField(size=16)), + ('minimum_count', IntField( + size=4 + )), + ('channel', FlagField( + size=4, + flag_type=ReadWriteChannel + )), + ('remaining_bytes', IntField(size=4)), + ('read_channel_info_offset', IntField( + size=2, + default=lambda s: self._get_read_channel_info_offset(s) + )), + ('read_channel_info_length', IntField( + size=2, + default=lambda s: self._get_read_channel_info_length(s) + )), + ('buffer', BytesField( + size=lambda s: self._get_buffer_length(s), + default=b"\x00" + )) + ]) + super(SMB2ReadRequest, self).__init__() + + def _get_read_channel_info_offset(self, structure): + if structure['channel'].get_value() == 0: + return 0 + else: + return 64 + structure['structure_size'].get_value() - 1 + + def _get_read_channel_info_length(self, structure): + if structure['channel'].get_value() == 0: + return 0 + else: + return len(structure['buffer'].get_value()) + + def _get_buffer_length(self, structure): + # buffer should contain 1 byte of \x00 and not be empty + if structure['channel'].get_value() == 0: + return 1 + else: + return structure['read_channel_info_length'].get_value() + + +class SMB2ReadResponse(Structure): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.20 SMB2 READ Response + Response to an SMB2 READ Request. + """ + COMMAND = Commands.SMB2_READ + + def __init__(self): + self.fields = OrderedDict([ + ('structure_size', IntField( + size=2, + default=17 + )), + ('data_offset', IntField(size=1)), + ('reserved', IntField(size=1)), + ('data_length', IntField( + size=4, + default=lambda s: len(s['buffer']) + )), + ('data_remaining', IntField(size=4)), + ('reserved2', IntField(size=4)), + ('buffer', BytesField( + size=lambda s: s['data_length'].get_value() + )) + ]) + super(SMB2ReadResponse, self).__init__() + + +class SMB2WriteRequest(Structure): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.21 SMB2 WRITE Request + A write packet to sent to an open file or named pipe on the server + """ + COMMAND = Commands.SMB2_WRITE + + def __init__(self): + self.fields = OrderedDict([ + ('structure_size', IntField( + size=2, + default=49 + )), + ('data_offset', IntField( # offset to the buffer field + size=2, + default=0x70 # seems to be hardcoded to this value + )), + ('length', IntField( + size=4, + default=lambda s: len(s['buffer']) + )), + ('offset', IntField(size=8)), # the offset in the file of the data + ('file_id', BytesField(size=16)), + ('channel', FlagField( + size=4, + flag_type=ReadWriteChannel + )), + ('remaining_bytes', IntField(size=4)), + ('write_channel_info_offset', IntField( + size=2, + default=lambda s: self._get_write_channel_info_offset(s) + )), + ('write_channel_info_length', IntField( + size=2, + default=lambda s: len(s['buffer_channel_info']) + )), + ('flags', FlagField( + size=4, + flag_type=WriteFlags + )), + ('buffer', BytesField( + size=lambda s: s['length'].get_value() + )), + ('buffer_channel_info', BytesField( + size=lambda s: s['write_channel_info_length'].get_value() + )) + ]) + super(SMB2WriteRequest, self).__init__() + + def _get_write_channel_info_offset(self, structure): + if len(structure['buffer_channel_info']) == 0: + return 0 + else: + header_size = 64 + packet_size = structure['structure_size'].get_value() - 1 + buffer_size = len(structure['buffer']) + return header_size + packet_size + buffer_size + + +class SMB2WriteResponse(Structure): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.22 SMB2 WRITE Response + The response to the SMB2 WRITE Request sent by the server + """ + COMMAND = Commands.SMB2_WRITE + + def __init__(self): + self.fields = OrderedDict([ + ('structure_size', IntField( + size=2, + default=17 + )), + ('reserved', IntField(size=2)), + ('count', IntField(size=4)), + ('remaining', IntField(size=4)), + ('write_channel_info_offset', IntField(size=2)), + ('write_channel_info_length', IntField(size=2)) + ]) + super(SMB2WriteResponse, self).__init__() + + +class SMB2QueryDirectoryRequest(Structure): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.33 QUERY_DIRECTORY Request + Used by the client to obtain a directory enumeration on a directory open. + """ + COMMAND = Commands.SMB2_QUERY_DIRECTORY + + def __init__(self): + self.fields = OrderedDict([ + ('structure_size', IntField( + size=2, + default=33 + )), + ('file_information_class', EnumField( + size=1, + enum_type=FileInformationClass + )), + ('flags', FlagField( + size=1, + flag_type=QueryDirectoryFlags + )), + ('file_index', IntField(size=4)), + ('file_id', BytesField(size=16)), + ('file_name_offset', IntField( + size=2, + default=lambda s: 0 if len(s['buffer']) == 0 else 96 + )), + ('file_name_length', IntField( + size=2, + default=lambda s: len(s['buffer']) + )), + ('output_buffer_length', IntField(size=4)), + # UTF-16-LE encoded search pattern + ('buffer', BytesField( + size=lambda s: s['file_name_length'].get_value() + )) + ]) + super(SMB2QueryDirectoryRequest, self).__init__() + + @staticmethod + def unpack_response(file_information_class, buffer): + """ + Pass in the buffer value from the response object to unpack it and + return a list of query response structures for the request. + + :param buffer: The raw bytes value of the SMB2QueryDirectoryResponse + buffer field. + :return: List of query_info.* structures based on the + FileInformationClass used in the initial query request. + """ + structs = smbprotocol.query_info + resp_structure = { + FileInformationClass.FILE_DIRECTORY_INFORMATION: + structs.FileDirectoryInformation, + FileInformationClass.FILE_NAMES_INFORMATION: + structs.FileNamesInformation, + FileInformationClass.FILE_BOTH_DIRECTORY_INFORMATION: + structs.FileBothDirectoryInformation, + FileInformationClass.FILE_ID_BOTH_DIRECTORY_INFORMATION: + structs.FileIdBothDirectoryInformation, + FileInformationClass.FILE_FULL_DIRECTORY_INFORMATION: + structs.FileFullDirectoryInformation, + FileInformationClass.FILE_ID_FULL_DIRECTORY_INFORMATION: + structs.FileIdFullDirectoryInformation, + }[file_information_class] + query_results = [] + + current_offset = 0 + is_next = True + while is_next: + result = resp_structure() + result.unpack(buffer[current_offset:]) + query_results.append(result) + current_offset += result['next_entry_offset'].get_value() + is_next = result['next_entry_offset'].get_value() != 0 + + return query_results + + +class SMB2QueryDirectoryResponse(Structure): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.34 SMB2 QUERY_DIRECTORY Response + Response to an SMB2 QUERY_DIRECTORY Request. + """ + + COMMAND = Commands.SMB2_QUERY_DIRECTORY + + def __init__(self): + self.fields = OrderedDict([ + ('structure_size', IntField( + size=2, + default=9 + )), + ('output_buffer_offset', IntField( + size=2, + default=72 + )), + ('output_buffer_length', IntField( + size=4, + default=lambda s: len(s['buffer']) + )), + # this structure varies based on the requested information class + ('buffer', BytesField( + size=lambda s: s['output_buffer_length'].get_value() + )) + ]) + super(SMB2QueryDirectoryResponse, self).__init__() + + +class Open(object): + + def __init__(self, tree, name): + """ + [MS-SMB2] v53.0 2017-09-15 + + 3.2.1.6 Per Application Open of a File + Attributes per each open of a file. A file can be a File, Pipe, + Directory, or Printer + + :param tree: The Tree (share) the file is located in. + :param name: The name of the file, excluding the share path, e.g. + \\server\share\folder\file.txt would be folder\file.txt + """ + # properties available based on the file itself + self._connected = False + self.creation_time = None + self.last_access_time = None + self.last_write_time = None + self.change_time = None + self.allocation_size = None + self.end_of_file = None + self.file_attributes = None + + # properties used privately + self.file_id = None + self.tree_connect = tree + self.connection = tree.session.connection + self.oplock_level = None + self.durable = None + self.file_name = name + self.resilient_handle = None + self.last_disconnect_time = None + self.resilient_timeout = None + + # an array of entries used to maintain information about outstanding + # lock and unlock operations performed on resilient Opens. Contains + # sequence_number - 4-bit integer modulo 16 + # free - boolean value where False is no outstanding requests + self.operation_buckets = [] + + # SMB 3.x+ + self.durable_timeout = None + + # Table of outstanding requests, lookup by Request.cancel_id, + # message_id + self.outstanding_requests = {} + + self.create_guid = None + self.is_persistent = None + self.desired_access = None + self.share_mode = None + self.create_options = None + self.file_attributes = None + self.create_disposition = None + + def open(self, impersonation_level, desired_access, file_attributes, + share_access, create_disposition, create_options, + create_contexts=None): + """ + This will open the file based on the input parameters supplied. Any + file open should also be called with Open.close() when it is finished. + + More details on how each option affects the open process can be found + here https://msdn.microsoft.com/en-us/library/cc246502.aspx. + + :param impersonation_level: (ImpersonationLevel) The type of + impersonation level that is issuing the create request. + :param desired_access: The level of access that is required of the + open. FilePipePrinterAccessMask or DirectoryAccessMask should be + used depending on the type of file being opened. + :param file_attributes: (FileAttributes) attributes to set on the file + being opened, this usually is for opens that creates a file. + :param share_access: (ShareAccess) Specifies the sharing mode for the + open. + :param create_disposition: (CreateDisposition) Defines the action the + server MUST take if the file already exists. + :param create_options: (CreateOptions) Specifies the options to be + applied when creating or opening the file. + :param create_contexts: (List) List of + context request values to be applied to the create. + + Create Contexts are used to encode additional flags and attributes when + opening files. More details on create context request values can be + found here https://msdn.microsoft.com/en-us/library/cc246504.aspx. + + :return: List of context response values or None if there are no + context response values. If the context response value is not known + to smbprotocol then the list value would be raw bytes otherwise + it is a Structure defined in create_contexts.py + """ + create = SMB2CreateRequest() + create['impersonation_level'] = impersonation_level + create['desired_access'] = desired_access + create['file_attributes'] = file_attributes + create['share_access'] = share_access + create['create_disposition'] = create_disposition + create['create_options'] = create_options + create['buffer_path'] = self.file_name.encode('utf-16-le') + if create_contexts: + create['buffer_contexts'] = smbprotocol.create_contexts.\ + SMB2CreateContextRequest.pack_multiple(create_contexts) + + log.info("Session: %s, Tree Connect: %s - sending SMB2 Create Request " + "for file %s" % (self.tree_connect.session.username, + self.tree_connect.share_name, + self.file_name)) + log.debug(str(create)) + request = self.connection.send(create, + self.tree_connect.session.session_id, + self.tree_connect.tree_connect_id) + + log.info("Session: %s, Tree Connect: %s - receiving SMB2 Create " + "Response" % (self.tree_connect.session.username, + self.tree_connect.share_name)) + response = self.connection.receive(request) + create_response = SMB2CreateResponse() + create_response.unpack(response['data'].get_value()) + self._connected = True + log.debug(str(create_response)) + + self.file_id = create_response['file_id'].get_value() + self.tree_connect.session.open_table[self.file_id] = self + self.oplock_level = create_response['oplock_level'].get_value() + self.durable = False + self.resilient_handle = False + self.last_disconnect_time = 0 + + if self.connection.dialect >= Dialects.SMB_3_0_0: + self.desired_access = desired_access + self.share_mode = share_access + self.create_options = create_options + self.file_attributes = file_attributes + self.create_disposition = create_disposition + + self.creation_time = create_response['creation_time'].get_value() + self.last_access_time = create_response['last_access_time'].get_value() + self.last_write_time = create_response['last_write_time'].get_value() + self.change_time = create_response['change_time'].get_value() + self.allocation_size = create_response['allocation_size'].get_value() + self.end_of_file = create_response['end_of_file'].get_value() + self.file_attributes = create_response['file_attributes'].get_value() + + create_contexts_response = None + if create_response['create_contexts_length'].get_value() > 0: + create_contexts_response = [] + for context in create_response['buffer'].get_value(): + create_contexts_response.append(context.get_context_data()) + + return create_contexts_response + + def read(self, offset, length, min_length=0, unbuffered=False, wait=False, + send=True): + """ + Reads from an opened file or pipe + + Supports out of band send function, call this function with send=False + to return a tuple of (SMB2ReadRequest, receive_func) instead of + sending the the request and waiting for the response. The receive_func + can be used to get the response from the server by passing in the + Request that was used to sent it out of band. + + :param offset: The offset to start the read of the file. + :param length: The number of bytes to read from the offset. + :param min_length: The minimum number of bytes to be read for a + successful operation. + :param unbuffered: Whether to the server should cache the read data at + intermediate layers, only value for SMB 3.0.2 or newer + :param wait: If send=True, whether to wait for a response if + STATUS_PENDING was received from the server or fail. + :param send: Whether to send the request in the same call or return the + message to the caller and the unpack function + :return: A byte string of the bytes read + """ + if length > self.connection.max_read_size: + raise SMBException("The requested read length %d is greater than " + "the maximum negotiated read size %d" + % (length, self.connection.max_read_size)) + + read = SMB2ReadRequest() + read['length'] = length + read['offset'] = offset + read['minimum_count'] = min_length + read['file_id'] = self.file_id + read['padding'] = b"\x50" + + if unbuffered: + if self.connection.dialect < Dialects.SMB_3_0_2: + raise SMBUnsupportedFeature(self.connection.dialect, + Dialects.SMB_3_0_2, + "SMB2_READFLAG_READ_UNBUFFERED", + True) + read['flags'].set_flag(ReadFlags.SMB2_READFLAG_READ_UNBUFFERED) + + if not send: + return read, self._read_response + + log.info("Session: %s, Tree Connect ID: %s - sending SMB2 Read " + "Request for file %s" % (self.tree_connect.session.username, + self.tree_connect.share_name, + self.file_name)) + log.debug(str(read)) + request = self.connection.send(read, + self.tree_connect.session.session_id, + self.tree_connect.tree_connect_id) + return self._read_response(request, wait) + + def _read_response(self, request, wait=False): + log.info("Session: %s, Tree Connect ID: %s - receiving SMB2 Read " + "Response" % (self.tree_connect.session.username, + self.tree_connect.share_name)) + response = self._get_read_write_response(request, wait) + read_response = SMB2ReadResponse() + read_response.unpack(response['data'].get_value()) + log.debug(str(read_response)) + + return read_response['buffer'].get_value() + + def write(self, data, offset=0, write_through=False, unbuffered=False, + wait=False, send=True): + """ + Writes data to an opened file. + + Supports out of band send function, call this function with send=False + to return a tuple of (SMBWriteRequest, receive_func) instead of + sending the the request and waiting for the response. The receive_func + can be used to get the response from the server by passing in the + Request that was used to sent it out of band. + + :param data: The bytes data to write. + :param offset: The offset in the file to write the bytes at + :param write_through: Whether written data is persisted to the + underlying storage, not valid for SMB 2.0.2. + :param unbuffered: Whether to the server should cache the write data at + intermediate layers, only value for SMB 3.0.2 or newer + :param wait: If send=True, whether to wait for a response if + STATUS_PENDING was received from the server or fail. + :param send: Whether to send the request in the same call or return the + message to the caller and the unpack function + :return: The number of bytes written + """ + data_len = len(data) + if data_len > self.connection.max_write_size: + raise SMBException("The requested write length %d is greater than " + "the maximum negotiated write size %d" + % (data_len, self.connection.max_write_size)) + + write = SMB2WriteRequest() + write['length'] = len(data) + write['offset'] = offset + write['file_id'] = self.file_id + write['buffer'] = data + + if write_through: + if self.connection.dialect < Dialects.SMB_2_1_0: + raise SMBUnsupportedFeature(self.connection.dialect, + Dialects.SMB_2_1_0, + "SMB2_WRITEFLAG_WRITE_THROUGH", + True) + write['flags'].set_flag(WriteFlags.SMB2_WRITEFLAG_WRITE_THROUGH) + + if unbuffered: + if self.connection.dialect < Dialects.SMB_3_0_2: + raise SMBUnsupportedFeature(self.connection.dialect, + Dialects.SMB_3_0_2, + "SMB2_WRITEFLAG_WRITE_UNBUFFERED", + True) + write['flags'].set_flag(WriteFlags.SMB2_WRITEFLAG_WRITE_UNBUFFERED) + + if not send: + return write, self._write_response + + log.info("Session: %s, Tree Connect: %s - sending SMB2 Write Request " + "for file %s" % (self.tree_connect.session.username, + self.tree_connect.share_name, + self.file_name)) + log.debug(str(write)) + request = self.connection.send(write, + self.tree_connect.session.session_id, + self.tree_connect.tree_connect_id) + return self._write_response(request, wait) + + def _write_response(self, request, wait=False): + log.info("Session: %s, Tree Connect: %s - receiving SMB2 Write " + "Response" % (self.tree_connect.session.username, + self.tree_connect.share_name)) + response = self._get_read_write_response(request, wait) + write_response = SMB2WriteResponse() + write_response.unpack(response['data'].get_value()) + log.debug(str(write_response)) + + return write_response['count'].get_value() + + def flush(self, send=True): + """ + A command sent by the client to request that a server flush all cached + file information for the opened file. + + Supports out of band send function, call this function with send=False + to return a tuple of (SMB2FlushRequest, receive_func) instead of + sending the the request and waiting for the response. The receive_func + can be used to get the response from the server by passing in the + Request that was used to sent it out of band. + + :param send: Whether to send the request in the same call or return the + message to the caller and the unpack function + :return: The SMB2FlushResponse received from the server + """ + flush = SMB2FlushRequest() + flush['file_id'] = self.file_id + + if not send: + return flush, self._flush_response + + log.info("Session: %s, Tree Connect: %s - sending SMB2 Flush Request " + "for file %s" % (self.tree_connect.session.username, + self.tree_connect.share_name, + self.file_name)) + log.debug(str(flush)) + request = self.connection.send(flush, + self.tree_connect.session.session_id, + self.tree_connect.tree_connect_id) + return self._flush_response(request) + + def _flush_response(self, request): + log.info("Session: %s, Tree Connect: %s - receiving SMB2 Flush " + "Response" % (self.tree_connect.session.username, + self.tree_connect.share_name)) + response = self.connection.receive(request) + flush_response = SMB2FlushResponse() + flush_response.unpack(response['data'].get_value()) + log.debug(str(flush_response)) + return flush_response + + def query_directory(self, pattern, file_information_class, flags=None, + file_index=0, max_output=65536, send=True): + """ + Run a Query/Find on an opened directory based on the params passed in. + + Supports out of band send function, call this function with send=False + to return a tuple of (SMB2QueryDirectoryRequest, receive_func) instead + of sending the the request and waiting for the response. The + receive_func can be used to get the response from the server by passing + in the Request that was used to sent it out of band. + + :param pattern: The string pattern to use for the query, this pattern + format is based on the SMB server but * is usually a wildcard + :param file_information_class: FileInformationClass that defines the + format of the result that is returned + :param flags: QueryDirectoryFlags that control how the operation must + be processed. + :param file_index: If the flags SMB2_INDEX_SPECIFIED, this is the index + the query should resume on, otherwise should be 0 + :param max_output: The maximum output size, defaulted to the max credit + size but can be increased to reduced round trip operations. + :param send: Whether to send the request in the same call or return the + message to the caller and the unpack function + :return: A list of structures defined in query_info.py, the list entry + structure is based on the value of file_information_class in the + request message + """ + query = SMB2QueryDirectoryRequest() + query['file_information_class'] = file_information_class + query['flags'] = flags + query['file_index'] = file_index + query['file_id'] = self.file_id + query['output_buffer_length'] = max_output + query['buffer'] = pattern.encode('utf-16-le') + + if not send: + return query, self._query_directory_response + + log.info("Session: %s, Tree Connect: %s - sending SMB2 Query " + "Directory Request for directory %s" + % (self.tree_connect.session.username, + self.tree_connect.share_name, self.file_name)) + log.debug(str(query)) + request = self.connection.send(query, + self.tree_connect.session.session_id, + self.tree_connect.tree_connect_id) + return self._query_directory_response(request) + + def _query_directory_response(self, request): + log.info("Session: %s, Tree Connect: %s - receiving SMB2 Query " + "Response" % (self.tree_connect.session.username, + self.tree_connect.share_name)) + response = self.connection.receive(request) + query_response = SMB2QueryDirectoryResponse() + query_response.unpack(response['data'].get_value()) + log.debug(str(query_response)) + + query_request = SMB2QueryDirectoryRequest() + query_request.unpack(request.message['data'].get_value()) + file_cl = query_request['file_information_class'].get_value() + data = query_response['buffer'].get_value() + results = SMB2QueryDirectoryRequest.unpack_response(file_cl, data) + return results + + def close(self, get_attributes=False, send=True): + """ + Closes an opened file. + + Supports out of band send function, call this function with send=False + to return a tuple of (SMB2CloseRequest, receive_func) instead of + sending the the request and waiting for the response. The receive_func + can be used to get the response from the server by passing in the + Request that was used to sent it out of band. + + :param get_attributes: (Bool) whether to get the latest attributes on + the close and set them on the Open object + :param send: Whether to send the request in the same call or return the + message to the caller and the unpack function + :return: SMB2CloseResponse message received from the server + """ + if not self._connected: + return + + close = SMB2CloseRequest() + + close['file_id'] = self.file_id + if get_attributes: + close['flags'] = CloseFlags.SMB2_CLOSE_FLAG_POSTQUERY_ATTRIB + + if not send: + return close, self._close_response + + log.info("Session: %s, Tree Connect: %s - sending SMB2 Close Request " + "for file %s" % (self.tree_connect.session.username, + self.tree_connect.share_name, + self.file_name)) + log.debug(str(close)) + request = self.connection.send(close, + self.tree_connect.session.session_id, + self.tree_connect.tree_connect_id) + return self._close_response(request) + + def _close_response(self, request): + log.info("Session: %s, Tree Connect: %s - receiving SMB2 Close " + "Response" % (self.tree_connect.session.username, + self.tree_connect.share_name)) + try: + response = self.connection.receive(request) + except SMBResponseException as exc: + # check if it was already closed + if exc.status == NtStatus.STATUS_FILE_CLOSED: + self._connected = False + self.tree_connect.session.open_table.pop(self.file_id, None) + return + # else raise the exception + raise exc + + c_resp = SMB2CloseResponse() + c_resp.unpack(response['data'].get_value()) + log.debug(str(c_resp)) + self._connected = False + del self.tree_connect.session.open_table[self.file_id] + + # update the attributes if requested + close_request = SMB2CloseRequest() + close_request.unpack(request.message['data'].get_value()) + if close_request['flags'].has_flag( + CloseFlags.SMB2_CLOSE_FLAG_POSTQUERY_ATTRIB): + self.creation_time = c_resp['creation_time'].get_value() + self.last_access_time = c_resp['last_access_time'].get_value() + self.last_write_time = c_resp['last_write_time'].get_value() + self.change_time = c_resp['change_time'].get_value() + self.allocation_size = c_resp['allocation_size'].get_value() + self.end_of_file = c_resp['end_of_file'].get_value() + self.file_attributes = c_resp['file_attributes'].get_value() + return c_resp + + def _get_read_write_response(self, request, wait=False): + # used by read and write to handle STATUS_PENDING on a read/write + # request + while True: + try: + response = self.connection.receive(request) + except SMBResponseException as exc: + if not wait or exc.status != NtStatus.STATUS_PENDING: + raise exc + else: + pass + else: + break + return response diff --git a/smbprotocol/query_info.py b/smbprotocol/query_info.py new file mode 100644 index 00000000..7c16c603 --- /dev/null +++ b/smbprotocol/query_info.py @@ -0,0 +1,217 @@ +import smbprotocol.open +from smbprotocol.structure import BytesField, DateTimeField, \ + FlagField, IntField, Structure + +try: + from collections import OrderedDict +except ImportError: # pragma: no cover + from ordereddict import OrderedDict + + +class FileBothDirectoryInformation(Structure): + """ + [MS-FSCC] 2.4.8 FileBothDirectoryInformation + https://msdn.microsoft.com/en-us/library/cc232095.aspx + """ + + def __init__(self): + self.fields = OrderedDict([ + ('next_entry_offset', IntField(size=4)), + ('file_index', IntField(size=4)), + ('creation_time', DateTimeField(size=8)), + ('last_access_time', DateTimeField(size=8)), + ('last_write_time', DateTimeField(size=8)), + ('change_time', DateTimeField(size=8)), + ('end_of_file', IntField(size=8)), + ('allocation_size', IntField(size=8)), + ('file_attributes', FlagField( + size=4, + flag_type=smbprotocol.open.FileAttributes + )), + ('file_name_length', IntField( + size=4, + default=lambda s: len(s['file_name']) + )), + ('ea_size', IntField(size=4)), + ('short_name_length', IntField( + size=1, + default=lambda s: len(s['short_name']) + )), + ('reserved', IntField(size=1)), + ('short_name', BytesField( + size=lambda s: s['short_name_length'].get_value() + )), + ('short_name_padding', BytesField( + size=lambda s: 24 - len(s['short_name']), + default=lambda s: b"\x00" * (24 - len(s['short_name'])) + )), + ('file_name', BytesField( + size=lambda s: s['file_name_length'].get_value() + )) + ]) + super(FileBothDirectoryInformation, self).__init__() + + +class FileDirectoryInformation(Structure): + """ + [MS-FSCC] 2.4.10 FileDirectoryInformation + https://msdn.microsoft.com/en-us/library/cc232097.aspx + """ + + def __init__(self): + self.fields = OrderedDict([ + ('next_entry_offset', IntField(size=4)), + ('file_index', IntField(size=4)), + ('creation_time', DateTimeField(size=8)), + ('last_access_time', DateTimeField(size=8)), + ('last_write_time', DateTimeField(size=8)), + ('change_time', DateTimeField(size=8)), + ('end_of_file', IntField(size=8)), + ('allocation_size', IntField(size=8)), + ('file_attributes', FlagField( + size=4, + flag_type=smbprotocol.open.FileAttributes + )), + ('file_name_length', IntField( + size=4, + default=lambda s: len(s['file_name']) + )), + ('file_name', BytesField( + size=lambda s: s['file_name_length'].get_value() + )) + ]) + super(FileDirectoryInformation, self).__init__() + + +class FileFullDirectoryInformation(Structure): + """ + [MS-FSCC] 2.4.14 FileFullDirectoryInformation + https://msdn.microsoft.com/en-us/library/cc232068.aspx + """ + + def __init__(self): + self.fields = OrderedDict([ + ('next_entry_offset', IntField(size=4)), + ('file_index', IntField(size=4)), + ('creation_time', DateTimeField(size=8)), + ('last_access_time', DateTimeField(size=8)), + ('last_write_time', DateTimeField(size=8)), + ('change_time', DateTimeField(size=8)), + ('end_of_file', IntField(size=8)), + ('allocation_size', IntField(size=8)), + ('file_attributes', FlagField( + size=4, + flag_type=smbprotocol.open.FileAttributes + )), + ('file_name_length', IntField( + size=4, + default=lambda s: len(s['file_name']) + )), + ('ea_size', IntField(size=4)), + ('file_name', BytesField( + size=lambda s: s['file_name_length'].get_value() + )) + ]) + super(FileFullDirectoryInformation, self).__init__() + + +class FileIdBothDirectoryInformation(Structure): + """ + [MS-FSCC] 2.4.17 FileIdBothDirectoryInformation + https://msdn.microsoft.com/en-us/library/cc232070.aspx + """ + + def __init__(self): + self.fields = OrderedDict([ + ('next_entry_offset', IntField(size=4)), + ('file_index', IntField(size=4)), + ('creation_time', DateTimeField(size=8)), + ('last_access_time', DateTimeField(size=8)), + ('last_write_time', DateTimeField(size=8)), + ('change_time', DateTimeField(size=8)), + ('end_of_file', IntField(size=8)), + ('allocation_size', IntField(size=8)), + ('file_attributes', FlagField( + size=4, + flag_type=smbprotocol.open.FileAttributes + )), + ('file_name_length', IntField( + size=4, + default=lambda s: len(s['file_name']) + )), + ('ea_size', IntField(size=4)), + ('short_name_length', IntField( + size=1, + default=lambda s: len(s['short_name']) + )), + ('reserved1', IntField(size=1)), + ('short_name', BytesField( + size=lambda s: s['short_name_length'].get_value() + )), + ('short_name_padding', BytesField( + size=lambda s: 24 - len(s['short_name']), + default=lambda s: b"\x00" * (24 - len(s['short_name'])) + )), + ('reserved2', IntField(size=2)), + ('file_id', IntField(size=8)), + ('file_name', BytesField( + size=lambda s: s['file_name_length'].get_value() + )) + ]) + super(FileIdBothDirectoryInformation, self).__init__() + + +class FileIdFullDirectoryInformation(Structure): + """ + [MS-FSCC] 2.4.18 FileIdFullDirectoryInformation + https://msdn.microsoft.com/en-us/library/cc232071.aspx + """ + + def __init__(self): + self.fields = OrderedDict([ + ('next_entry_offset', IntField(size=4)), + ('file_index', IntField(size=4)), + ('creation_time', DateTimeField(size=8)), + ('last_access_time', DateTimeField(size=8)), + ('last_write_time', DateTimeField(size=8)), + ('change_time', DateTimeField(size=8)), + ('end_of_file', IntField(size=8)), + ('allocation_size', IntField(size=8)), + ('file_attributes', FlagField( + size=4, + flag_type=smbprotocol.open.FileAttributes + )), + ('file_name_length', IntField( + size=4, + default=lambda s: len(s['file_name']) + )), + ('ea_size', IntField(size=4)), + ('reserved', IntField(size=4)), + ('file_id', IntField(size=8)), + ('file_name', BytesField( + size=lambda s: s['file_name_length'].get_value() + )) + ]) + super(FileIdFullDirectoryInformation, self).__init__() + + +class FileNamesInformation(Structure): + """ + [MS-FSCC] 2.4.26 FileNamesInformation + https://msdn.microsoft.com/en-us/library/cc232077.aspx + """ + + def __init__(self): + self.fields = OrderedDict([ + ('next_entry_offset', IntField(size=4)), + ('file_index', IntField(size=4)), + ('file_name_length', IntField( + size=4, + default=lambda s: len(s['file_name']) + )), + ('file_name', BytesField( + size=lambda s: s['file_name_length'].get_value() + )) + + ]) + super(FileNamesInformation, self).__init__() diff --git a/smbprotocol/security_descriptor.py b/smbprotocol/security_descriptor.py new file mode 100644 index 00000000..81a3bf56 --- /dev/null +++ b/smbprotocol/security_descriptor.py @@ -0,0 +1,444 @@ +import struct + +from smbprotocol.structure import BytesField, EnumField, FlagField, \ + IntField, ListField, Structure, StructureField + +try: + from collections import OrderedDict +except ImportError: # pragma: no cover + from ordereddict import OrderedDict + + +class AccessMask(object): + """ + [MS-DTYP] + + 2.4.3 ACCESS_MASK + 32-bit set of flags that are used to encode the user rights to an object. + This is just a generic setup of access mask flags to set an can vary + from the object being set. When setting the AccessMask on an ACE packet, + any 32-bit value can be used and this is just as a guideline. + """ + GENERIC_READ = 0x80000000 + GENERIC_WRITE = 0x40000000 + GENERIC_EXECUTE = 0x20000000 + GENERIC_ALL = 0x10000000 + MAXIMUM_ALLOWED = 0x02000000 + ACCESS_SYSTEM_SECURITY = 0x01000000 + SYNCHRONIZE = 0x00100000 + WRITE_OWNER = 0x00080000 + WRITE_DACL = 0x00040000 + READ_CONTROL = 0x00020000 + DELETE = 0x00010000 + + +class AceType(object): + """ + [MS-DTYP] + + 2.4.4.1 ACE_HEADER AceType + The type of ACE in the ACE packet. + """ + # Current only have structures for the first 3 + ACCESS_ALLOWED_ACE_TYPE = 0x00 + ACCESS_DENIED_ACE_TYPE = 0x01 + SYSTEM_AUDIT_ACE_TYPE = 0x02 + + # No structures are defined for the below + SYSTEM_ALARM_ACE_TYPE = 0x03 + ACCESS_ALLOWED_COMPOUND_ACE_TYPE = 0x04 + ACCESS_ALLOWED_OBJECT_ACE_TYPE = 0x05 + ACCESS_DENIED_OBJECT_ACE_TYPE = 0x06 + SYSTEM_AUDIT_OBJECT_ACE_TYPE = 0x07 + SYSTEM_ALARM_OBJECT_ACE_TYPE = 0x08 + ACCESS_ALLOWED_CALLBACK_ACE_TYPE = 0x09 + ACCESS_DENIED_CALLBACK_ACE_TYPE = 0x0a + ACCESS_ALLOWED_CALLBACK_OBJECT_ACE_TYPE = 0x0b + ACCESS_DENIED_CALLBACK_OBJECT_ACE_TYPE = 0x0c + SYSTEM_AUDIT_CALLBACK_ACE_TYPE = 0x0d + SYSTEM_ALARM_CALLBACK_ACE_TYPE = 0x0e + SYSTEM_AUDIT_CALLBACK_OBJECT_ACE_TYPE = 0x0f + SYSTEM_ALARM_CALLBACK_OBJECT_ACE_TYPE = 0x10 + SYSTEM_MANDATORY_LABEL_ACE_TYPE = 0x11 + SYSTEM_RESOURCE_ATTRIBUTE_ACE_TYPE = 0x12 + SYSTEM_SCOPED_POLICY_ID_ACE_TYPE = 0x13 + + +class AceFlags(object): + """ + [MS-DTYP] + + 2.4.4.1 ACE_HEADER AceFlags + Controls the ACE specified in the ACE packet. + """ + CONTAINER_INHERIT_ACE = 0x02 + FAILED_ACCESS_ACE_FLAG = 0x80 + INHERIT_ONLY_ACE = 0x08 + INHERITED_ACE = 0x10 + NO_PROPAGATE_INHERITY_ACE = 0x04 + OBJECT_INHERIT_ACE = 0x01 + SUCCESSFUL_ACCESS_ACE_FLAG = 0x40 + + +class AclRevision(object): + """ + [MS-DTYP] + + 2.4.5 ACL AclRevision + ACL_REVISION - AceType 0, 1, 2, 3, 11, 12, 13 are valid + ACL_REVISION_DS - AceType 5, 6, 7, 8, 11 are valid (Directory Service) + """ + ACL_REVISION = 0x02 + ACL_REVISION_DS = 0x04 # not natively supported yet + + +class SDControl(object): + """ + [MS-DTYP] + + 2.4.6 SECURITY_DESCRIPTOR Control + Specifies control access bit flags. + """ + SELF_RELATIVE = 0x8000 + RM_CONTROL_VALID = 0x4000 + SACL_PROTECTED = 0x2000 + DACL_PROTECTED = 0x1000 + SACL_AUTO_INHERITED = 0x0800 + DACL_AUTO_INHERITED = 0x0400 + SACL_COMPUTED_INHERITANCE_REQUIRED = 0x0200 + DACL_COMPUTED_INHERITANCE_REQUIRED = 0x0100 + SERVER_SECURITY = 0x0080 + DACL_TRUSTED = 0x0040 + SACL_DEFAULTED = 0x0020 + SACL_PRESENT = 0x0010 + DACL_DEFAULTED = 0x0008 + DACL_PRESENT = 0x0004 + GROUP_DEFAULTED = 0x0002 + OWNER_DEFAULTED = 0x0001 + NONE = 0x0000 + + +class SIDPacket(Structure): + """ + [MS-DTYP] 2.4.2.2 SID--Packet Representation + + The packet representation of the SID type for use by block protocols. While + the values can be set explicitly, it may be easier to use the from_string + function store the byte structure from the string format. + """ + + def __init__(self): + self.fields = OrderedDict([ + ('revision', IntField( + size=1, + default=1 + )), + ('sub_authority_count', IntField( + size=1, + default=lambda s: len(s['sub_authorities'].get_value()) + )), + ('reserved', IntField(size=2)), + ('identifier_authority', IntField( + size=4, + little_endian=False + )), + ('sub_authorities', ListField( + list_type=IntField(size=4), + list_count=lambda s: s['sub_authority_count'].get_value() + )) + ]) + super(SIDPacket, self).__init__() + + def __str__(self): + revision = self['revision'].get_value() + id_authority = self['identifier_authority'].get_value() + sub_authorities = self['sub_authorities'].get_value() + sid_string = "S-%d-%d-%s" % (revision, id_authority, + "-".join(str(x) for x in sub_authorities)) + return sid_string + + def from_string(self, sid_string): + """ + Used to set the structure parameters based on the input string + + :param sid_string: String of the sid in S-x-x-x-x form + """ + if not sid_string.startswith("S-"): + raise ValueError("A SID string must start with S-") + + sid_entries = sid_string.split("-") + if len(sid_entries) < 3: + raise ValueError("A SID string must start with S and contain a " + "revision and identifier authority, e.g. S-1-0") + + revision = int(sid_entries[1]) + id_authority = int(sid_entries[2]) + sub_authorities = [int(i) for i in sid_entries[3:]] + + self['revision'].set_value(revision) + self['identifier_authority'].set_value(id_authority) + self['sub_authorities'] = sub_authorities + + +class AccessAllowedAce(Structure): + """ + [MS-DTYP] 2.4.4.3 ACCESS_ALLOWED_ACE + + Used for the DACL that controls access to an object. + """ + + def __init__(self): + self.fields = OrderedDict([ + ('ace_type', EnumField( + size=1, + default=AceType.ACCESS_ALLOWED_ACE_TYPE, + enum_type=AceType + )), + ('ace_flags', FlagField( + size=1, + flag_type=AceFlags + )), + ('ace_size', IntField( + size=2, + default=lambda s: 8 + len(s['sid']) + )), + ('mask', FlagField( + size=4, + flag_type=AccessMask, + flag_strict=False + )), + ('sid', StructureField( + structure_type=SIDPacket + )) + ]) + super(AccessAllowedAce, self).__init__() + + +class AccessDeniedAce(Structure): + """ + [MS-DTYP] 2.4.4.4 ACCESS_DENIED_ACE + + Used for the DACL that controls denies to an object. + """ + + def __init__(self): + self.fields = OrderedDict([ + ('ace_type', EnumField( + size=1, + default=AceType.ACCESS_DENIED_ACE_TYPE, + enum_type=AceType + )), + ('ace_flags', FlagField( + size=1, + flag_type=AceFlags + )), + ('ace_size', IntField( + size=2, + default=lambda s: 8 + len(s['sid']) + )), + ('mask', FlagField( + size=4, + flag_type=AccessMask, + flag_strict=False + )), + ('sid', StructureField( + structure_type=SIDPacket + )) + ]) + super(AccessDeniedAce, self).__init__() + + +class SystemAuditAce(Structure): + """ + [MS-DTYP] 2.4.4.10 SYSTEM_AUDIT_ACE + + Used for the SACL that specifies what types of access cause system-level + notifications. + """ + + def __init__(self): + self.fields = OrderedDict([ + ('ace_type', EnumField( + size=1, + default=AceType.SYSTEM_AUDIT_ACE_TYPE, + enum_type=AceType + )), + ('ace_flags', FlagField( + size=1, + flag_type=AceFlags + )), + ('ace_size', IntField( + size=2, + default=lambda s: 8 + len(s['sid']) + )), + ('mask', FlagField( + size=4, + flag_type=AccessMask, + flag_strict=False + )), + ('sid', StructureField( + structure_type=SIDPacket + )) + ]) + super(SystemAuditAce, self).__init__() + + +class AclPacket(Structure): + """ + [MS-DTYP] 2.4.5 ACL + + Access Control List packet is used to specify a list of individual ACEs. + An ACL is said to be in canonical form if: + All explicit ACEs are placed before inherited ACEs + Within the explicit ACEs, deny ACEs come before grant ACEs + Deny ACEs on the object come before deny ACEs on a child or property + Grant ACEs on the object come before grant ACEs on a child or property + Inherited ACEs are placed in the order in which they were inherited + """ + + def __init__(self): + self.fields = OrderedDict([ + ('acl_revision', EnumField( + size=1, + default=AclRevision.ACL_REVISION, + enum_type=AclRevision + )), + ('sbz1', IntField(size=1)), + ('acl_size', IntField( + size=2, + default=lambda s: 8 + len(s['aces']) + )), + ('ace_count', IntField( + size=2, + default=lambda s: len(s['aces'].get_value()) + )), + ('sbz2', IntField(size=2)), + ('aces', ListField( + list_count=lambda s: s['ace_count'].get_value(), + unpack_func=lambda s, d: self._unpack_aces(s, d) + )) + ]) + super(AclPacket, self).__init__() + + def _unpack_aces(self, structure, data): + aces = [] + while data != b"": + ace_type = struct.unpack("= Dialects.SMB_3_1_1: + preauth_hash = b"\x00" * 64 + hash_al = self.connection.preauth_integrity_hash_id + for message in self.preauth_integrity_hash_value: + preauth_hash = hash_al(preauth_hash + message.pack()).digest() + + self.signing_key = self._smb3kdf(self.session_key, + b"SMBSigningKey\x00", + preauth_hash) + self.application_key = self._smb3kdf(self.session_key, + b"SMBAppKey\x00", + preauth_hash) + self.encryption_key = self._smb3kdf(self.session_key, + b"SMBC2SCipherKey\x00", + preauth_hash) + self.decryption_key = self._smb3kdf(self.session_key, + b"SMBS2CCipherKey\x00", + preauth_hash) + elif self.connection.dialect >= Dialects.SMB_3_0_0: + self.signing_key = self._smb3kdf(self.session_key, + b"SMB2AESCMAC\x00", + b"SmbSign\x00") + self.application_key = self._smb3kdf(self.session_key, + b"SMB2APP\x00", b"SmbRpc\x00") + self.encryption_key = self._smb3kdf(self.session_key, + b"SMB2AESCCM\x00", + b"ServerIn \x00") + self.decryption_key = self._smb3kdf(self.session_key, + b"SMB2AESCCM\x00", + b"ServerOut\x00") + else: + self.signing_key = self.session_key + self.application_key = self.session_key + + flags = setup_response['session_flags'] + if flags.has_flag(SessionFlags.SMB2_SESSION_FLAG_ENCRYPT_DATA) or \ + self.require_encryption: + # make sure the connection actually supports encryption + if not self.connection.supports_encryption: + raise SMBException("SMB encryption is required but the " + "connection does not support it") + self.encrypt_data = True + self.signing_required = False # encryption covers signing + else: + self.encrypt_data = False + + if flags.has_flag(SessionFlags.SMB2_SESSION_FLAG_IS_GUEST) or \ + flags.has_flag(SessionFlags.SMB2_SESSION_FLAG_IS_NULL): + self.session_key = None + self.signing_key = None + self.application_key = None + self.encryption_key = None + self.decryption_key = None + if self.signing_required or self.encrypt_data: + self.session_id = None + raise SMBException("SMB encryption or signing was required " + "but session was authenticated as a guest " + "which does not support encryption or " + "signing") + + if self.signing_required: + log.info("Verifying the SMB Setup Session signature as auth is " + "successful") + self.connection._verify(response, True) + + def disconnect(self, close=True): + """ + Logs off the session + + :param close: Will close all tree connects in a session + """ + if not self._connected: + # already disconnected so let's return + return + + if close: + for open in list(self.open_table.values()): + open.close(False) + + for tree in list(self.tree_connect_table.values()): + tree.disconnect() + + log.info("Session: %d - Logging off of SMB Session" % self.session_id) + logoff = SMB2Logoff() + log.info("Session: %d - Sending Logoff message" % self.session_id) + log.debug(str(logoff)) + request = self.connection.send(logoff, sid=self.session_id) + + log.info("Session: %d - Receiving Logoff response" % self.session_id) + res = self.connection.receive(request) + res_logoff = SMB2Logoff() + res_logoff.unpack(res['data'].get_value()) + log.debug(str(res_logoff)) + self._connected = False + del self.connection.session_table[self.session_id] + + def _authenticate_session(self, mech): + if mech in [MechTypes.KRB5, MechTypes.MS_KRB5] and HAVE_GSSAPI: + context = GSSAPIContext(username=self.username, + password=self.password, + server=self.connection.server_name) + elif mech in [MechTypes.KRB5, MechTypes.MS_KRB5, MechTypes.NTLMSSP] \ + and HAVE_SSPI: + raise NotImplementedError("SSPI on Windows for authentication is " + "not yet implemented") + elif mech == MechTypes.NTLMSSP: + context = NtlmContext(username=self.username, + password=self.password) + else: + raise NotImplementedError("Mech Type %s is not yet supported" + % mech) + + for out_token in context.step(): + session_setup = SMB2SessionSetupRequest() + session_setup['security_mode'] = \ + self.connection.client_security_mode + session_setup['buffer'] = out_token + + log.info("Sending SMB2_SESSION_SETUP request message") + request = self.connection.send(session_setup, + sid=self.session_id, + credit_request=256) + self.preauth_integrity_hash_value.append(request.message) + + log.info("Receiving SMB2_SESSION_SETUP response message") + try: + response = self.connection.receive(request) + except SMBResponseException as exc: + if exc.status != NtStatus.STATUS_MORE_PROCESSING_REQUIRED: + raise exc + mid = request.message['message_id'].get_value() + del self.connection.outstanding_requests[mid] + response = exc.header + + self.session_id = response['session_id'].get_value() + session_resp = SMB2SessionSetupResponse() + session_resp.unpack(response['data'].get_value()) + + context.in_token = session_resp['buffer'].get_value() + status = response['status'].get_value() + if status == NtStatus.STATUS_MORE_PROCESSING_REQUIRED: + log.info("More processing is required for SMB2_SESSION_SETUP") + self.preauth_integrity_hash_value.append(response) + + # Once the context is established, we need the session key which is + # used to derive the signing and sealing keys for SMB + session_key = context.get_session_key() + + return response, session_key + + def _smb3kdf(self, ki, label, context): + """ + See SMB 3.x key derivation function + https://blogs.msdn.microsoft.com/openspecification/2017/05/26/smb-2-and-smb-3-security-in-windows-10-the-anatomy-of-signing-and-cryptographic-keys/ + + :param ki: The session key is the KDK used as an input to the KDF + :param label: The purpose of this derived key as bytes string + :param context: The context information of this derived key as bytes + string + :return: Key derived by the KDF as specified by [SP800-108] 5.1 + """ + kdf = KBKDFHMAC( + algorithm=hashes.SHA256(), + mode=Mode.CounterMode, + length=16, + rlen=4, + llen=4, + location=CounterLocation.BeforeFixed, + label=label, + context=context, + fixed=None, + backend=default_backend() + ) + return kdf.derive(ki) + + +class NtlmContext(object): + + def __init__(self, username, password): + # try and get the domain part from the username + log.info("Setting up NTLM Security Context for user %s" % username) + try: + self.domain, self.username = username.split("\\", 1) + except ValueError: + self.username = username + self.domain = '' + self.password = password + self.context = Ntlm() + self.in_token = None + + def step(self): + log.info("NTLM: Generating Negotiate message") + msg1 = self.context.create_negotiate_message(self.domain) + msg1 = base64.b64decode(msg1) + log.debug("NTLM: Negotiate message: %s" % _bytes_to_hex(msg1)) + yield msg1 + + log.info("NTLM: Parsing Challenge message") + msg2 = base64.b64encode(self.in_token) + log.debug("NTLM: Challenge message: %s" % _bytes_to_hex(self.in_token)) + self.context.parse_challenge_message(msg2) + + log.info("NTLM: Generating Authenticate message") + msg3 = self.context.create_authenticate_message( + user_name=self.username, + password=self.password, + domain_name=self.domain + ) + yield base64.b64decode(msg3) + + def get_session_key(self): + return self.context.authenticate_message.exported_session_key + + +class GSSAPIContext(object): + + def __init__(self, username, password, server): + log.info("Setting up GSSAPI Security Context for Kerberos auth") + self.creds = self._acquire_creds(username, password) + + server_spn = "cifs@%s" % server + log.debug("GSSAPI Server SPN Target: %s" % server_spn) + server_name = gssapi.Name(base=server_spn, + name_type=gssapi.NameType.hostbased_service) + self.context = gssapi.SecurityContext(name=server_name, + creds=self.creds, + usage='initiate') + self.in_token = None + + def step(self): + while not self.context.complete: + log.info("GSSAPI: gss_init_sec_context called") + out_token = self.context.step(self.in_token) + if out_token: + yield out_token + else: + log.info("GSSAPI: gss_init_sec_context complete") + + def get_session_key(self): + # GSS_C_INQ_SSPI_SESSION_KEY + session_key_oid = gssapi.OID.from_int_seq("1.2.840.113554.1.2.2.5.5") + context_data = gssapi.raw.inquire_sec_context_by_oid(self.context, + session_key_oid) + + return context_data[0] + + def _acquire_creds(self, username, password): + # 3 use cases with Kerberos AUth + # 1. Both the user and pass is supplied so we want to create a new + # ticket with the pass + # 2. Only the user is supplied so we will attempt to get the cred + # from the existing store + # 3. The user is not supplied so we will attempt to get the default + # cred from the existing store + log.info("GSSAPI: Acquiring credentials handle") + if username and password: + log.debug("GSSAPI: Acquiring credentials handle for user %s with " + "password" % username) + user = gssapi.Name(base=username, + name_type=gssapi.NameType.user) + bpass = password.encode('utf-8') + try: + creds = gssapi.raw.acquire_cred_with_password(user, bpass, + usage='initiate') + except AttributeError: + raise SMBAuthenticationError("Cannot get GSSAPI credential " + "with password as the necessary " + "GSSAPI extensions are not " + "available") + except gssapi.exceptions.GSSError as er: + raise SMBAuthenticationError("Failed to acquire GSSAPI " + "credential with password: %s" + % str(er)) + # acquire_cred_with_password returns a wrapper, we want the creds + # object inside this wrapper + creds = creds.creds + elif username: + log.debug("GSSAPI: Acquiring credentials handle for user %s from " + "existing cache" % username) + user = gssapi.Name(base=username, + name_type=gssapi.NameType.user) + + try: + creds = gssapi.Credentials(name=user, usage='initiate') + except gssapi.exceptions.MissingCredentialsError as er: + raise SMBAuthenticationError("Failed to acquire GSSAPI " + "credential for user %s from the " + "exisiting cache: %s" + % (str(user), str(er))) + else: + log.debug("GSSAPI: Acquiring credentials handle for default user " + "in cache") + try: + creds = gssapi.Credentials(name=None, usage='initiate') + except gssapi.exceptions.GSSError as er: + raise SMBAuthenticationError("Failed to acquire default " + "GSSAPI credential from the " + "existing cache: %s" % str(er)) + user = creds.name + + log.info("GSSAPI: Acquired credentials for user %s" % str(user)) + return creds diff --git a/smbprotocol/spnego.py b/smbprotocol/spnego.py new file mode 100644 index 00000000..1e104581 --- /dev/null +++ b/smbprotocol/spnego.py @@ -0,0 +1,285 @@ +from pyasn1.type.char import GeneralString +from pyasn1.type.constraint import SingleValueConstraint +from pyasn1.type.namedtype import NamedType, NamedTypes, OptionalNamedType +from pyasn1.type.namedval import NamedValues +from pyasn1.type.tag import Tag, tagClassApplication, tagClassContext, \ + tagFormatConstructed, tagFormatSimple, TagSet +from pyasn1.type.univ import BitString, Choice, Enumerated, ObjectIdentifier, \ + OctetString, Sequence, SequenceOf + + +class MechTypes(object): + # Currently only NTLMSSP is supported, with the aim to support Kerberos + MS_KRB5 = ObjectIdentifier('1.2.840.48018.1.2.2') + KRB5 = ObjectIdentifier('1.2.840.113554.1.2.2') + KRB5_U2U = ObjectIdentifier('1.2.840.113554.1.2.2.3') + NEGOEX = ObjectIdentifier('1.3.6.1.4.1.311.2.2.30') + NTLMSSP = ObjectIdentifier('1.3.6.1.4.1.311.2.2.10') + + +class MechType(ObjectIdentifier): + """ + [RFC-4178] + + 4.1 Mechanism Types + OID represents one GSS-API mechanism according to RFC-2743. + + MechType ::= OBJECT IDENTIFIER + """ + pass + + +class MechTypeList(SequenceOf): + """ + [RFC-4178] + + 4.1 Mechanism Types + List of MechTypes + + MechTypeList ::= SEQUENCE OF MechType + """ + componentType = MechType() + + +class ContextFlags(BitString): + """ + [RFC-4178] + + ContextFlags ::= BIT STRING { + delegFlag (0), + mutualFlag (1), + replayFlag (2), + sequenceFlag (3), + anonFlag (4), + confFlag (5), + integFlag (6) + } + """ + componentType = NamedValues( + ('delegFlag', 0), + ('mutualFlag', 1), + ('replayFlag', 2), + ('sequenceFlag', 3), + ('anonFlag', 4), + ('confFlag', 5), + ('integFlag', 6) + ) + + +class NegStat(Enumerated): + """ + [RFC-4178] + + NegState ::= ENUMERATED { + accept-completed (0), + accept-incomplete (1), + reject (2), + request-mic (3) + } + """ + namedValues = NamedValues( + ('accept-complete', 0), + ('accept-incomplete', 1), + ('reject', 2), + ('request-mic', 3) + ) + subtypeSpec = Enumerated.subtypeSpec + SingleValueConstraint(0, 1, 2, 3) + + +class NegHints(Sequence): + """ + [MS-SPNG] v14.0 2017-09-15 + + 2.2.1 NegTokenInit2 + NegHints is an extension of NegTokenInit. + + NegHints ::= SEQUENCE { + hintName[0] GeneralString OPTIONAL, + hintAddress[1] OCTET STRING OPTIONAL + } + """ + componentType = NamedTypes( + OptionalNamedType( + 'hintName', GeneralString().subtype( + explicitTag=Tag(tagClassContext, tagFormatSimple, 0) + ) + ), + OptionalNamedType( + 'hintAddress', OctetString().subtype( + explicitTag=Tag(tagClassContext, tagFormatSimple, 1) + ) + ) + ) + + +class NegTokenInit(Sequence): + """ + [RFC-4178] + + NegTokenInit ::= SEQUENCE { + mechTypes [0] MechTypeList, + regFlags [1] ContextFlags OPTIONAL, + mechToken [2] OCTET STRING OPTIONAL, + mechListMIC [3] OCTER STRING OPTIONAL, + ... + } + """ + componentType = NamedTypes( + NamedType( + 'mechTypes', MechTypeList().subtype( + explicitTag=Tag(tagClassContext, tagFormatConstructed, 0) + ) + ), + OptionalNamedType( + 'reqFlags', ContextFlags().subtype( + explicitTag=Tag(tagClassContext, tagFormatConstructed, 1) + ) + ), + OptionalNamedType( + 'mechToken', OctetString().subtype( + explicitTag=Tag(tagClassContext, tagFormatSimple, 2) + ) + ), + OptionalNamedType( + 'mechListMIC', OctetString().subtype( + explicitTag=Tag(tagClassContext, tagFormatSimple, 3) + ) + ) + ) + + +class NegTokenInit2(Sequence): + """ + [MS-SPNG] v14.0 2017-09-15 + + 2.2.1 NegTokenInit2 + NegTokenInit2 is the message structure that extends NegTokenInit with a + negotiation hints (negHints) field. On a server initiated SPNEGO process, + it sends negTokenInit2 message instead of just the plain NegTokenInit. + + NegTokenInit2 ::= SEQUENCE { + mechTypes [0] MechTypeList OPTIONAL, + reqFlags [1] ContextFlags OPTIONAL, + mechToken [2] OCTET STRING OPTIONAL, + negHints [3] NegHints OPTIONAL, + mechListMIC [4] OCTET STRING OPTIONAL, + ... + } + """ + componentType = NamedTypes( + OptionalNamedType( + 'mechTypes', MechTypeList().subtype( + explicitTag=Tag(tagClassContext, tagFormatConstructed, 0) + ) + ), + OptionalNamedType( + 'reqFlags', ContextFlags().subtype( + explicitTag=Tag(tagClassContext, tagFormatConstructed, 1) + ) + ), + OptionalNamedType( + 'mechToken', OctetString().subtype( + explicitTag=Tag(tagClassContext, tagFormatSimple, 2) + ) + ), + OptionalNamedType( + 'negHints', NegHints().subtype( + explicitTag=Tag(tagClassContext, tagFormatConstructed, 3) + ) + ), + OptionalNamedType( + 'mechListMIC', OctetString().subtype( + explicitTag=Tag(tagClassContext, tagFormatSimple, 4) + ) + ), + ) + + +class NegTokenResp(Sequence): + """ + [RFC-4178] + + 4.2.2 negTokenResp + The response message for NegTokenInit. + + NegTokenResp ::= SEQUENCE { + negStat [0] NegState OPTIONAL, + supportedMech [1] MechType OPTIONAL, + responseToken [2] OCTET STRING OPTIONAL, + mechListMIC {3] OCTET STRING OPTIONAL, + ... + } + """ + componentType = NamedTypes( + OptionalNamedType( + 'negStat', NegStat().subtype( + explicitTag=Tag(tagClassContext, tagFormatConstructed, 0) + ) + ), + OptionalNamedType( + 'supportedMech', ObjectIdentifier().subtype( + explicitTag=Tag(tagClassContext, tagFormatSimple, 1) + ) + ), + OptionalNamedType( + 'responseToken', OctetString().subtype( + explicitTag=Tag(tagClassContext, tagFormatSimple, 2) + ) + ), + OptionalNamedType( + 'mechListMIC', OctetString().subtype( + explicitTag=Tag(tagClassContext, tagFormatSimple, 3) + ) + ) + ) + + +class NegotiateToken(Choice): + """ + [RFC-4178] + + NegotiateToken ::= CHOICE { + negTokenInit [0] NegTokenInit, + negTokenResp [1] NegTokenResp + } + """ + componentType = NamedTypes( + NamedType( + 'negTokenInit', NegTokenInit2().subtype( + explicitTag=Tag(tagClassContext, tagFormatConstructed, 0) + ) + ), + NamedType( + 'negTokenResp', NegTokenResp().subtype( + explicitTag=Tag(tagClassContext, tagFormatConstructed, 1) + ) + ) + ) + + +class InitialContextToken(Sequence): + """ + [RFC-2743] + + 3.1. Mechanism-Independent Token Format + This section specifies a mechanism-independent level of encapsulating + representation for the initial token of a GSS-API context establishment + sequence. + + InitialContextToken ::= [APPLICATION 0] IMPLICIT SEQUENCE { + thisMech MechType, + innerContextToken NegotiateToken + } + """ + componentType = NamedTypes( + NamedType( + 'thisMech', ObjectIdentifier() + ), + NamedType( + 'innerContextToken', NegotiateToken() + ) + ) + tagSet = TagSet( + Sequence.tagSet, + Tag(tagClassApplication, tagFormatConstructed, 0), + ) diff --git a/smbprotocol/structure.py b/smbprotocol/structure.py new file mode 100644 index 00000000..9bd6aeb2 --- /dev/null +++ b/smbprotocol/structure.py @@ -0,0 +1,869 @@ +import copy +import struct +import textwrap +import types +import uuid +from abc import ABCMeta, abstractmethod +from binascii import hexlify +from datetime import datetime, timedelta + +from six import with_metaclass, integer_types + +TAB = " " # Instead of displaying a tab on the print, use 4 spaces + + +class InvalidFieldDefinition(Exception): + pass + + +def _bytes_to_hex(bytes, pretty=False, hex_per_line=8): + hex = hexlify(bytes).decode('utf-8') + + if pretty: + if hex_per_line == 0: # show hex on 1 line + hex_list = [hex] + else: + idx = hex_per_line * 2 + hex_list = list(hex[i:i + idx] for i in range(0, len(hex), idx)) + + hexes = [] + for h in hex_list: + hexes.append( + ' '.join(h[i:i + 2] for i in range(0, len(h), 2)).upper()) + hex = "\n".join(hexes) + + return hex + + +def _indent_lines(string, prefix): + # Would use textwrap.indent for this but it is not available for Python 2 + def predicate(line): + return line.strip() + + lines = [] + for line in string.splitlines(True): + lines.append(prefix + line if predicate(line) else line) + return ''.join(lines) + + +class Structure(object): + + def __init__(self): + # Now that self.fields is set, loop through it again and set the + # metadata around the fields and set the value based on default. + # This must be done outside of the OrderedDict definition as set_value + # relies on the full structure (self) being available and error + # messages use the field name to be helpful + for name, field in self.fields.items(): + field.structure = self + field.name = name + field.set_value(field.default) + + def __str__(self): + struct_name = self.__class__.__name__ + raw_hex = _bytes_to_hex(self.pack(), True, hex_per_line=0) + field_strings = [] + + for name, field in self.fields.items(): + # the field header is slightly different for a StructureField + # remove the leading space and put the value on the next line + if isinstance(field, StructureField): + field_header = "%s =\n%s" + else: + field_header = "%s = %s" + + field_string = field_header % (field.name, str(field)) + field_strings.append(_indent_lines(field_string, TAB)) + + field_strings.append("") + field_strings.append(_indent_lines("Raw Hex:", TAB)) + hex_wrapper = textwrap.TextWrapper( + width=33, # set to show 8 hex values per line, 33 for 8, 56 for 16 + initial_indent=TAB + TAB, + subsequent_indent=TAB + TAB + ) + field_strings.append(hex_wrapper.fill(raw_hex)) + + string = "%s:\n%s" % (struct_name, '\n'.join(field_strings)) + + return string + + def __setitem__(self, key, value): + field = self._get_field(key) + field.set_value(value) + + def __getitem__(self, key): + return self._get_field(key) + + def __delitem__(self, key): + self._get_field(key) + del self.fields[key] + + def __len__(self): + length = 0 + for field in self.fields.values(): + length += len(field) + return length + + def pack(self): + data = b"" + for field in self.fields.values(): + field_data = field.pack() + data += field_data + + return data + + def unpack(self, data): + for key, field in self.fields.items(): + data = field.unpack(data) + return data # remaining data + + def _get_field(self, key): + field = self.fields.get(key, None) + if field is None: + raise ValueError("Structure does not contain field %s" % key) + return field + + +class Field(with_metaclass(ABCMeta, object)): + + def __init__(self, little_endian=True, default=None, size=None): + """ + The base class of a Field object. This contains the framework that a + field SHOULD implement in regards to packing and unpacking a value. + There should be little need to call this particular object as it is + designed to be a base class for *Type classes. + + :param little_endian: When converting an int/uuid to bytes, the byte + order to pack as, False means it will be big endian + :param default: The default value of the field, this can be any + supported value such as as well as a lambda function or None + (default). + :param size: The size of the field, this can be an int, lambda function + or None (for variable length end field) unless overridden in Class + definition. + """ + field_type = self.__class__.__name__ + self.little_endian = little_endian + + if not (size is None or isinstance(size, integer_types) or + isinstance(size, types.LambdaType)): + raise InvalidFieldDefinition("%s size for field must be an int or " + "None for a variable length" + % field_type) + self.size = size + self.default = default + self.value = None + + def __str__(self): + return self._to_string() + + def __len__(self): + return self._get_packed_size() + + def pack(self): + """ + Packs the field value into a byte string so it can be sent to the + server. + + :param structure: The message structure class object + :return: A byte string of the packed field's value + """ + value = self._get_calculated_value(self.value) + packed_value = self._pack_value(value) + size = self._get_calculated_size(self.size, packed_value) + if len(packed_value) != size: + raise ValueError("Invalid packed data length for field %s of %d " + "does not fit field size of %d" + % (self.name, len(packed_value), size)) + + return packed_value + + def get_value(self): + """ + Returns the value set for the field, will run any lambda functions + that is set under the value attribute and return the final value. + + :return: The value attribute with lambda functions run if value is a + lambda function + """ + return self._get_calculated_value(self.value) + + def set_value(self, value): + """ + Parses, and sets the value attribute for the field. + + :param value: The value to be parsed and set, the allowed input types + vary depending on the Field used + """ + parsed_value = self._parse_value(value) + self.value = parsed_value + + def unpack(self, data): + """ + Takes in a byte string and set's the field value based on field + definition. + + :param structure: The message structure class object + :param data: The byte string of the data to unpack + :return: The remaining data for subsequent fields + """ + size = self._get_calculated_size(self.size, data) + self.set_value(data[0:size]) + return data[len(self):] + + @abstractmethod + def _pack_value(self, value): + """ + Packs the value passed in according to the rules of the FieldType. + + :param value: The value to be packed, this is derived by + _get_calculated_value(self.value) + :return: A byte string of the data once packed + """ + pass # pragma: no cover + + @abstractmethod + def _parse_value(self, value): + """ + Parses the value into the FieldType type, this also validates that + the value is allowable by the FieldType. + + :param value: The value to parse + :return: The value that has been parsed/casted to the correct value + """ + pass # pragma: no cover + + @abstractmethod + def _get_packed_size(self): + """ + Get's the size of the data once it has been packed. Depending on the + FieldType, this can either be pre-set or calculated when called. + + :return: The size of the field once it is packed + """ + pass # pragma: no cover + + @abstractmethod + def _to_string(self): + """ + Creates a string which is a human readable representation of the value. + The output is dependent on the field implementation. + + :return: string of the field value + """ + # creates a string which is a friendly representation of the value + pass # pragma: no cover + + def _get_calculated_value(self, value): + """ + Get's the final value of the field and runs the lambda functions + recursively until a final value is derived. + + :param value: The value to calculate/expand + :return: The final value + """ + if isinstance(value, types.LambdaType): + expanded_value = value(self.structure) + return self._get_calculated_value(expanded_value) + else: + # perform one final parsing of the value in case lambda value + # returned a different type + return self._parse_value(value) + + def _get_calculated_size(self, size, data): + """ + Get's the final size of the field and runs the lambda functions + recursively until a final size is derived. If size is None then it + will just return the length of the data as it is assumed it is the + final field (None should only be set on size for the final field). + + :param size: The size to calculate/expand + :param data: The data that the size is being calculated for + :return: The final size + """ + # if the size is derived from a lambda function, run it now; otherwise + # return the value we passed in or the length of the data if the size + # is None (last field value) + if size is None: + return len(data) + elif isinstance(size, types.LambdaType): + expanded_size = size(self.structure) + return self._get_calculated_size(expanded_size, data) + else: + return size + + def _get_struct_format(self, size): + """ + Get's the format specified for use in struct. This is only designed + for 1, 2, 4, or 8 byte values and will throw an exception if it is + anything else. + + :param size: The size as an int + :return: The struct format specifier for the size specified + """ + if isinstance(size, types.LambdaType): + size = size(self.structure) + + struct_format = { + 1: 'B', + 2: 'H', + 4: 'L', + 8: 'Q' + } + if size not in struct_format.keys(): + raise InvalidFieldDefinition("Cannot struct format of size %s" + % size) + return struct_format[size] + + +class IntField(Field): + + def __init__(self, size, **kwargs): + """ + Used to store an int value for a field. The size for these values MUST + be 1, 2, 4, or 8 and if another size is required use the BytesField + instead and store the values as bytes. + + :param size: The size of the integer when packed + :param kwargs: Any other kwarg to be sent to Field() + """ + if size not in [1, 2, 4, 8]: + raise InvalidFieldDefinition("IntField size must have a value of " + "1, 2, 4, or 8 not %s" % str(size)) + super(IntField, self).__init__(size=size, **kwargs) + + def _pack_value(self, value): + format = self._get_struct_format(self.size) + struct_string = "%s%s" % ("<" if self.little_endian else ">", format) + packed_int = struct.pack(struct_string, value) + return packed_int + + def _parse_value(self, value): + if value is None: + int_value = 0 + elif isinstance(value, types.LambdaType): + int_value = value + elif isinstance(value, bytes): + format = self._get_struct_format(self.size) + struct_string = "%s%s"\ + % ("<" if self.little_endian else ">", format) + int_value = struct.unpack(struct_string, value)[0] + elif isinstance(value, integer_types): + int_value = value + else: + raise TypeError("Cannot parse value for field %s of type %s to " + "an int" % (self.name, type(value).__name__)) + return int_value + + def _get_packed_size(self): + return self.size + + def _to_string(self): + return str(self._get_calculated_value(self.value)) + + +class BytesField(Field): + """ + Used to store a raw bytes value as a field. Is the most universal and can + convert from most objects to a bytes string. Use this is the field can + contain multiple values and parsing will be done outside of the class. + """ + + def _pack_value(self, value): + return value + + def _parse_value(self, value): + if value is None: + bytes_value = b"" + elif isinstance(value, types.LambdaType): + bytes_value = value + elif isinstance(value, integer_types): + format = self._get_struct_format(self.size) + struct_string = "%s%s"\ + % ("<" if self.little_endian else ">", format) + bytes_value = struct.pack(struct_string, value) + elif isinstance(value, Structure): + bytes_value = value.pack() + elif isinstance(value, bytes): + bytes_value = value + else: + raise TypeError("Cannot parse value for field %s of type %s to a " + "byte string" % (self.name, type(value).__name__)) + return bytes_value + + def _get_packed_size(self): + bytes_value = self._get_calculated_value(self.value) + return len(bytes_value) + + def _to_string(self): + bytes_value = self._get_calculated_value(self.value) + return _bytes_to_hex(bytes_value, pretty=True, hex_per_line=0) + + +class ListField(Field): + + def __init__(self, list_count=None, list_type=BytesField(), + unpack_func=None, **kwargs): + """ + Used to store a list of values that are the same time, the list can + contain both fixed length values or variable length values but the + former is easier to use as it does not require lambda functions to + unpack the values. If the list values are different types, then the + BytesField list_type should be used and the data will automatically + will be converted to a bytes object. If appending a value to the list, + ensure the value it added as an actual *Field() object and not just + the raw value. + + :param list_count: The number of entries in the list, the value can be + an int, lambda function or None (for variable length). The lambda + function is only evaluated in the pack and unpack methods. This + must be set if unpack_func is not set so it can unpack the data + receved from the server. + :param list_type: The *Field() definition for each list entry, defaults + to a variable length BytesField. If unpack_func is not set, the + size attribute must be set. + :param unpack_func: A lambda function used during the unpack method to + unpack the data received from the server to a list. It takes in the + (structure, data) arguments which is the structure of the whole + packet and the remaining data left to be unpacked. This MUST be + used when the list contains variable length values. + :param kwargs: Any other kwarg to be sent to Field() + """ + if list_count is not None and not \ + (isinstance(list_count, integer_types) or + isinstance(list_count, types.LambdaType)): + raise InvalidFieldDefinition("ListField list_count must be an " + "int, lambda, or None for a variable " + "list length") + self.list_count = list_count + + if not isinstance(list_type, Field): + raise InvalidFieldDefinition("ListField list_type must be a " + "Field definition") + self.list_type = list_type + + if unpack_func is not None and not isinstance(unpack_func, + types.LambdaType): + raise InvalidFieldDefinition("ListField unpack_func must be a " + "lambda function or None") + elif unpack_func is None and \ + (list_count is None or list_type.size is None): + raise InvalidFieldDefinition("ListField must either define " + "unpack_func as a lambda or set " + "list_count and list_size with a " + "size") + self.unpack_func = unpack_func + + super(ListField, self).__init__(**kwargs) + + def __getitem__(self, item): + # TODO: Make this more efficient + return self.get_value()[item] + + def get_value(self): + # Override default get_value() so we return a list with the actual + # value, not the Field definition + list_value = [] + if isinstance(self.value, types.LambdaType): + value = self._get_calculated_value(self.value) + else: + value = self.value + + for entry in value: + list_value.append(entry.get_value()) + return list_value + + def _pack_value(self, value): + data = b"" + for value in list(value): + data += value.pack() + return data + + def _parse_value(self, value): + if value is None: + list_value = [] + elif isinstance(value, types.LambdaType): + return value + elif isinstance(value, bytes) and isinstance(self.unpack_func, + types.LambdaType): + # use the lambda function to parse the bytes to a list + list_value = self.unpack_func(self.structure, value) + elif isinstance(value, bytes): + # we have a fixed length array with a specified count + list_value = self._create_list_from_bytes(self.list_count, + self.list_type, value) + elif isinstance(value, list): + # manually parse each list entry to the field type specified + list_value = value + else: + raise TypeError("Cannot parse value for field %s of type %s to a " + "list" % (self.name, type(value).__name__)) + list_value = [self._parse_sub_value(v) for v in list_value] + return list_value + + def _parse_sub_value(self, value): + if isinstance(value, Field): + new_field = value + elif isinstance(value, Structure): + new_field = StructureField( + size=len(value), + structure_type=type(value), + default=value, + ) + new_field.name = "%s list entry" % self.name + new_field.structure = value + new_field.set_value(new_field.default) + else: + new_field = copy.deepcopy(self.list_type) + new_field.name = "%s list entry" % self.name + new_field.set_value(value) + return new_field + + def _get_packed_size(self): + list_value = self._get_calculated_value(self.value) + size = 0 + for field in list(list_value): + size += len(field) + return size + + def _to_string(self): + list_value = self._get_calculated_value(self.value) + list_string = [_indent_lines(str(v), TAB) for v in list(list_value)] + if len(list_string) == 0: + string = "[]" + else: + string = "[\n%s\n]" % ',\n'.join(list_string) + return string + + def _create_list_from_bytes(self, list_count, list_type, value): + # calculate the list_count and rerun method if a lambda + if isinstance(list_count, types.LambdaType): + list_count = list_count(self.structure) + return self._create_list_from_bytes(list_count, list_type, value) + + list_value = [] + for idx in range(0, list_count): + new_field = copy.deepcopy(list_type) + value = new_field.unpack(value) + list_value.append(new_field) + return list_value + + +class StructureField(Field): + + def __init__(self, structure_type, **kwargs): + """ + Used to store a message packet Structure object as a field. Can store + both an actual Structure value or a byte string. + + :param structure_type: The message structure type, e.g. + SMB2NegotiateRequest. Used to marshal a byte string to a structure + object when unpacking or setting a value + :param kwargs: Any other kwarg to be sent to Field() + """ + self.structure_type = structure_type + super(StructureField, self).__init__(**kwargs) + + def __setitem__(self, key, value): + field = self._get_field(key) + field.set_value(value) + + def __getitem__(self, key): + return self._get_field(key) + + def set_structure_type(self, structure_type): + # Set's the structure type and convert a byte string to the actual + # structure specified + self.structure_type = structure_type + self.set_value(self.value) + + def _pack_value(self, value): + # Can either be a Structure or just plain bytes, just pack the + # structure if needed + if isinstance(value, Structure): + value = value.pack() + return value + + def _parse_value(self, value): + if value is None: + structure_value = b"" + elif isinstance(value, types.LambdaType): + structure_value = value + elif isinstance(value, bytes): + structure_value = value + elif isinstance(value, Structure): + structure_value = value + else: + raise TypeError("Cannot parse value for field %s of type %s to a " + "structure" % (self.name, type(value).__name__)) + + if isinstance(structure_value, bytes) and self.structure_type and \ + structure_value != b"": + if isinstance(self.structure_type, types.LambdaType): + structure_type = self.structure_type(self.structure) + else: + structure_type = self.structure_type + structure = structure_type() + structure.unpack(structure_value) + structure_value = structure + return structure_value + + def _get_packed_size(self): + structure_value = self._get_calculated_value(self.value) + return len(structure_value) + + def _to_string(self): + structure_value = self._get_calculated_value(self.value) + return str(structure_value) + + def _get_field(self, key): + structure_value = self._get_calculated_value(self.value) + if isinstance(structure_value, bytes): + raise ValueError("Cannot get field %s when structure is defined " + "as a byte string" % key) + field = structure_value._get_field(key) + return field + + +class DateTimeField(Field): + + EPOCH_FILETIME = 116444736000000000 # epoch as a MS FILETIME int + HUNDREDS_NS = 10000000 # How many hundred nanoseconds in a second + + def __init__(self, size=None, **kwargs): + """ + [MS-DTYP] 0.0 2017-09-15 + + 2.3.3 FILETIME + The FILETIME structure is a 64-it value that represents the number of + 100 nanoseconds intervals that have elapsed since January 1, 1601 UTC. + This is used to convert the FILETIME int value to a native Python + datetime object. + + While the format FILETIME is used when communicating with the server, + this type allows Python code to interact with datetime objects natively + with all the conversions handled at pack/unpack time. + + :param size: Must be set to None or 8, this is so we can check/override + :param kwargs: Any other kwarg to be sent to Field() + """ + if not (size is None or size == 8): + raise InvalidFieldDefinition("DateTimeField type must have a size " + "of 8 not %d" % size) + super(DateTimeField, self).__init__(size=8, **kwargs) + + def _pack_value(self, value): + epoch_seconds = self._seconds_since_epoch(value) + int_value = self.EPOCH_FILETIME + (epoch_seconds * self.HUNDREDS_NS) + int_value += value.microsecond * 10 + + format = self._get_struct_format(8) + struct_string = "%s%s"\ + % ("<" if self.little_endian else ">", format) + bytes_value = struct.pack(struct_string, int_value) + + return bytes_value + + def _parse_value(self, value): + if value is None: + datetime_value = datetime.today() + elif isinstance(value, types.LambdaType): + datetime_value = value + elif isinstance(value, bytes): + format = self._get_struct_format(8) + struct_string = "%s%s"\ + % ("<" if self.little_endian else ">", format) + int_value = struct.unpack(struct_string, value)[0] + return self._parse_value(int_value) # just parse the value again + elif isinstance(value, integer_types): + + time_microseconds = (value - self.EPOCH_FILETIME) // 10 + datetime_value = datetime(1970, 1, 1) + \ + timedelta(microseconds=time_microseconds) + elif isinstance(value, datetime): + datetime_value = value + else: + raise TypeError("Cannot parse value for field %s of type %s to a " + "datetime" % (self.name, type(value).__name__)) + return datetime_value + + def _get_packed_size(self): + return self.size + + def _to_string(self): + datetime_value = self._get_calculated_value(self.value) + return datetime_value.isoformat(' ') + + def _seconds_since_epoch(self, datetime_value): + # total_seconds was not present in Python 2.6, this is suggested by + # Python docs as an alternative + # https://docs.python.org/2/library/datetime.html#datetime.timedelta.total_seconds + td = datetime_value - datetime.utcfromtimestamp(0) + seconds = (td.microseconds + + (td.seconds + td.days * 24 * 3600) * 10 ** 6) / 10 ** 6 + return int(seconds) + + +class UuidField(Field): + + def __init__(self, size=None, **kwargs): + """ + Used to store a UUID (GUID) as a Python UUID object. + + :param size: Must be set to None or 16, this is so we can + check/override + :param kwargs: Any other kwarg to be sent to Field() + """ + if not (size is None or size == 16): + raise InvalidFieldDefinition("UuidField type must have a size of " + "16 not %d" % size) + super(UuidField, self).__init__(size=16, **kwargs) + + def _pack_value(self, value): + if self.little_endian: + return value.bytes + else: + return value.bytes_le + + def _parse_value(self, value): + if value is None: + uuid_value = uuid.UUID(bytes=b"\x00" * 16) + elif isinstance(value, bytes) and self.little_endian: + uuid_value = uuid.UUID(bytes=value) + elif isinstance(value, bytes) and not self.little_endian: + uuid_value = uuid.UUID(bytes_le=value) + elif isinstance(value, integer_types): + uuid_value = uuid.UUID(int=value) + elif isinstance(value, uuid.UUID): + uuid_value = value + elif isinstance(value, types.LambdaType): + uuid_value = value + else: + raise TypeError("Cannot parse value for field %s of type %s to a " + "uuid" % (self.name, type(value).__name__)) + return uuid_value + + def _get_packed_size(self): + return self.size + + def _to_string(self): + uuid_value = self._get_calculated_value(self.value) + return str(uuid_value) + + +class EnumField(IntField): + + def __init__(self, enum_type, enum_strict=True, **kwargs): + self.enum_type = enum_type + self.enum_strict = enum_strict + super(EnumField, self).__init__(**kwargs) + + def _parse_value(self, value): + int_value = super(EnumField, self)._parse_value(value) + valid = False + for flag_value in vars(self.enum_type).values(): + if int_value == flag_value: + valid = True + break + + if not valid and int_value != 0 and self.enum_strict: + raise ValueError("Enum value %d does not exist in enum type %s" + % (int_value, self.enum_type)) + return int_value + + def _to_string(self): + enum_name = None + value = self._get_calculated_value(self.value) + for enum, enum_value in vars(self.enum_type).items(): + if value == enum_value: + enum_name = enum + break + if enum_name is None: + return "(%d) UNKNOWN_ENUM" % value + else: + return "(%d) %s" % (value, enum_name) + + +class FlagField(IntField): + + def __init__(self, flag_type, flag_strict=True, **kwargs): + self.flag_type = flag_type + self.flag_strict = flag_strict + super(FlagField, self).__init__(**kwargs) + + def set_flag(self, flag): + valid = False + for value in vars(self.flag_type).values(): + if flag == value: + valid = True + break + + if not valid and self.flag_strict: + raise ValueError("Flag value does not exist in flag type %s" + % self.flag_type) + self.set_value(self.value | flag) + + def has_flag(self, flag): + return self.value & flag == flag + + def _parse_value(self, value): + int_value = super(FlagField, self)._parse_value(value) + current_val = int_value + for value in vars(self.flag_type).values(): + if isinstance(value, int): + current_val &= ~value + if current_val != 0 and self.flag_strict: + raise ValueError("Invalid flag for field %s value set %d" + % (self.name, current_val)) + + return int_value + + def _to_string(self): + field_value = self._get_calculated_value(self.value) + if field_value == 0: + return "0" + flags = [] + for flag, value in vars(self.flag_type).items(): + if isinstance(value, int) and self.has_flag(value): + flags.append(flag) + flags.sort() + return "(%d) %s" % (field_value, ", ".join(flags)) + + +class BoolField(Field): + + def __init__(self, size=1, **kwargs): + """ + Used to store a boolean value in 1 byte. b"\x00" is False while b"\x01" + is True. + + :param kwargs: Any other kwargs to be sent to Field() + """ + if size != 1: + raise InvalidFieldDefinition("BoolField size must have a value of " + "1, not %d" % size) + super(BoolField, self).__init__(size=size, **kwargs) + + def _pack_value(self, value): + return b"\x01" if value else b"\x00" + + def _parse_value(self, value): + if value is None: + bool_value = False + elif isinstance(value, bool): + bool_value = value + elif isinstance(value, bytes): + bool_value = value == b"\x01" + elif isinstance(value, types.LambdaType): + bool_value = value + else: + raise TypeError("Cannot parse value for field %s of type %s to a " + "bool" % (self.name, type(value).__name__)) + return bool_value + + def _get_packed_size(self): + return 1 + + def _to_string(self): + return str(self._get_calculated_value(self.value)) diff --git a/smbprotocol/transport.py b/smbprotocol/transport.py new file mode 100644 index 00000000..f4b99f01 --- /dev/null +++ b/smbprotocol/transport.py @@ -0,0 +1,109 @@ +import logging +import socket +import struct + +from multiprocessing.dummy import Process, Queue + +from smbprotocol.structure import BytesField, IntField, Structure + +try: + from collections import OrderedDict +except ImportError: # pragma: no cover + from ordereddict import OrderedDict + +log = logging.getLogger(__name__) + + +class DirectTCPPacket(Structure): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.1 Transport + The Directory TCP transport packet header MUST have the following + structure. + """ + + def __init__(self): + self.fields = OrderedDict([ + ('stream_protocol_length', IntField( + size=4, + little_endian=False, + default=lambda s: len(s['smb2_message']), + )), + ('smb2_message', BytesField( + size=lambda s: s['stream_protocol_length'].get_value(), + )), + ]) + super(DirectTCPPacket, self).__init__() + + +class Tcp(object): + + MAX_SIZE = 16777215 + + def __init__(self, server, port): + log.info("Setting up DirectTcp connection on %s:%d" % (server, port)) + self.message_buffer = Queue() + self.server = server + self.port = port + + self._connected = False + self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._listener = Process(target=self._listen, + args=(self._sock, self.message_buffer)) + + def connect(self): + if not self._connected: + log.info("Connecting to DirectTcp socket") + self._sock.connect((self.server, self.port)) + self._connected = True + + if not self._listener.is_alive(): + log.info("Setting up DirectTcp listener") + self._listener.start() + + def disconnect(self): + if self._connected: + log.info("Disconnecting DirectTcp socket") + try: + self._sock.shutdown(socket.SHUT_RDWR) + except socket.error: + # socket has already been shutdown + pass + self._listener.join() + self._sock.close() + self._connected = False + + def send(self, request): + data_length = len(request) + if data_length > self.MAX_SIZE: + raise ValueError("Data to be sent over Direct TCP size %d exceeds " + "the max length allowed %d" + % (data_length, self.MAX_SIZE)) + + tcp_packet = DirectTCPPacket() + tcp_packet['smb2_message'] = request + data = tcp_packet.pack() + self._sock.sendall(data) + + @staticmethod + def _listen(sock, message_buffer): + """ + Runs in a thread and is constantly reading from the socket receive + buffer and adding each message to the queue. Very little error handling + and message parsing is done in this process as it happens + asynchronously to the main process + + :param sock: The socket to read from + :param message_buffer: A queue used to store the incoming messages for + Connection to read from + """ + while True: + packet_size_bytes = sock.recv(4) + # the socket was closed so exit the loop + if not packet_size_bytes: + break + + packet_size_int = struct.unpack(">L", packet_size_bytes)[0] + buffer = sock.recv(packet_size_int) + message_buffer.put(buffer) diff --git a/smbprotocol/tree.py b/smbprotocol/tree.py new file mode 100644 index 00000000..68d4582d --- /dev/null +++ b/smbprotocol/tree.py @@ -0,0 +1,337 @@ +import logging + +from smbprotocol.connection import Commands, Dialects +from smbprotocol.exceptions import SMBException +from smbprotocol.ioctl import CtlCode, IOCTLFlags, SMB2IOCTLRequest, \ + SMB2IOCTLResponse, SMB2ValidateNegotiateInfoRequest, \ + SMB2ValidateNegotiateInfoResponse +from smbprotocol.structure import BytesField, EnumField, FlagField, IntField, \ + Structure + +try: + from collections import OrderedDict +except ImportError: # pragma: no cover + from ordereddict import OrderedDict + +log = logging.getLogger(__name__) + + +class TreeFlags(object): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.9 SMB2 TREE_CONNECT Response Flags + Flags used in SMB 3.1.1 to indicate how to process the operation. + """ + SMB2_TREE_CONNECT_FLAG_CLUSTER_RECONNECT = 0x0004 + SMB2_TREE_CONNECT_FLAG_REDIRECT_TO_OWNER = 0x0002 + SMB2_TREE_CONNECT_FLAG_EXTENSION_PRESENT = 0x0001 + + +class ShareType(object): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.10 SMB2 TREE_CONNECT Response Capabilities + The type of share being accessed + """ + SMB2_SHARE_TYPE_DISK = 0x01 + SMB2_SHARE_TYPE_PIPE = 0x02 + SMB2_SHARE_TYPE_PRINT = 0x03 + + +class ShareFlags(object): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.10 SMB2 TREE_CONNECT Response Capabilities + Properties for the share + """ + SMB2_SHAREFLAG_MANUAL_CACHING = 0x00000000 + SMB2_SHAREFLAG_AUTO_CACHING = 0x00000010 + SMB2_SHAREFLAG_VDO_CACHING = 0x00000020 + SMB2_SHAREFLAG_NO_CACHING = 0x00000030 + SMB2_SHAREFLAG_DFS = 0x00000001 + SMB2_SHAREFLAG_DFS_ROOT = 0x00000002 + SMB2_SHAREFLAG_RESTRICT_EXCLUSIVE_OPENS = 0x00000100 + SMB2_SHAREFLAG_FORCE_SHARED_DELETE = 0x00000200 + SMB2_SHAREFLAG_ALLOW_NAMESPACE_CACHING = 0x00000400 + SMB2_SHAREFLAG_ACCESS_BASED_DIRECTORY_ENUM = 0x00000800 + SMB2_SHAREFLAG_FORCE_LEVELII_OPLOCK = 0x00001000 + SMB2_SHAREFLAG_ENABLE_HASH_V1 = 0x00002000 + SMB2_SHAREFLAG_ENABLE_HASH_V2 = 0x00004000 + SMB2_SHAREFLAG_ENCRYPT_DATA = 0x00008000 + SMB2_SHAREFLAG_IDENTITY_REMOTING = 0x00040000 + + +class ShareCapabilities(object): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.10 SMB2 TREE_CONNECT Response Capabilities + Indicates various capabilities for a share + """ + SMB2_SHARE_CAP_DFS = 0x00000008 + SMB2_SHARE_CAP_CONTINUOUS_AVAILABILITY = 0x00000010 + SMB2_SHARE_CAP_SCALEOUT = 0x00000020 + SMB2_SHARE_CAP_CLUSTER = 0x00000040 + SMB2_SHARE_CAP_ASYMMETRIC = 0x00000080 + SMB2_SHARE_CAP_REDIRECT_TO_OWNER = 0x00000100 + + +class SMB2TreeConnectRequest(Structure): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.9 SMB2 TREE_CONNECT Request + Sent by the client to request access to a particular share on the server + """ + COMMAND = Commands.SMB2_TREE_CONNECT + + def __init__(self): + self.fields = OrderedDict([ + ('structure_size', IntField( + size=2, + default=9 + )), + ('flags', FlagField( + size=2, + flag_type=TreeFlags, + )), + ('path_offset', IntField( + size=2, + default=64 + 8, + )), + ('path_length', IntField( + size=2, + default=lambda s: len(s['buffer']), + )), + ('buffer', BytesField( + size=lambda s: s['path_length'].get_value() + )) + ]) + super(SMB2TreeConnectRequest, self).__init__() + + +class SMB2TreeConnectResponse(Structure): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.10 SMB2 TREE_CONNECT Response + Sent by the server when an SMB2 TREE_CONNECT request is processed + successfully. + """ + COMMAND = Commands.SMB2_TREE_CONNECT + + def __init__(self): + self.fields = OrderedDict([ + ('structure_size', IntField( + size=2, + default=16 + )), + ('share_type', EnumField( + size=1, + enum_type=ShareType, + )), + ('reserved', IntField(size=1)), + ('share_flags', FlagField( + size=4, + flag_type=ShareFlags, + )), + ('capabilities', FlagField( + size=4, + flag_type=ShareCapabilities, + )), + ('maximal_access', IntField(size=4)) + ]) + super(SMB2TreeConnectResponse, self).__init__() + + +class SMB2TreeDisconnect(Structure): + """ + [MS-SMB2] v53.0 2017-09-15 + + 2.2.11/12 SMB2 TREE_DISCONNECT Request and Response + Sent by the client to request that the tree connect specific by tree_id in + the header is disconnected. + """ + COMMAND = Commands.SMB2_TREE_DISCONNECT + + def __init__(self): + self.fields = OrderedDict([ + ('structure_size', IntField( + size=2, + default=4, + )), + ('reserved', IntField(size=2)) + ]) + super(SMB2TreeDisconnect, self).__init__() + + +class TreeConnect(object): + + def __init__(self, session, share_name): + """ + [MS-SMB2] v53.0 2017-09-15 + + 3.2.1.4 Per Tree Connect + Attributes per Tree Connect (share connections) + + :param session: The Session to connect to the tree with + :param share_name: The name of the share, including the server name, + e.g. \\server\share + """ + self._connected = False + self.open_table = {} + + self.share_name = share_name + self.tree_connect_id = None + self.session = session + self.is_dfs_share = None + + # SMB 3.x+ + self.is_ca_share = None + self.encrypt_data = None + self.is_scaleout_share = None + + def connect(self, require_secure_negotiate=True): + """ + Connect to the share. + + :param require_secure_negotiate: For Dialects 3.0 and 3.0.2, will + verify the negotiation parameters with the server to prevent + SMB downgrade attacks + """ + log.info("Session: %s - Creating connection to share %s" + % (self.session.username, self.share_name)) + utf_share_name = self.share_name.encode('utf-16-le') + connect = SMB2TreeConnectRequest() + connect['buffer'] = utf_share_name + + log.info("Session: %s - Sending Tree Connect message" + % self.session.username) + log.debug(str(connect)) + request = self.session.connection.send(connect, + sid=self.session.session_id) + + log.info("Session: %s - Receiving Tree Connect response" + % self.session.username) + response = self.session.connection.receive(request) + tree_response = SMB2TreeConnectResponse() + tree_response.unpack(response['data'].get_value()) + log.debug(str(tree_response)) + + # https://msdn.microsoft.com/en-us/library/cc246687.aspx + self.tree_connect_id = response['tree_id'].get_value() + log.info("Session: %s - Created tree connection with ID %d" + % (self.session.username, self.tree_connect_id)) + self._connected = True + self.session.tree_connect_table[self.tree_connect_id] = self + + capabilities = tree_response['capabilities'] + self.is_dfs_share = capabilities.has_flag( + ShareCapabilities.SMB2_SHARE_CAP_DFS) + self.is_ca_share = capabilities.has_flag( + ShareCapabilities.SMB2_SHARE_CAP_CONTINUOUS_AVAILABILITY) + + dialect = self.session.connection.dialect + if dialect >= Dialects.SMB_3_0_0 and \ + self.session.connection.supports_encryption: + self.encrypt_data = tree_response['share_flags'].has_flag( + ShareFlags.SMB2_SHAREFLAG_ENCRYPT_DATA) + + self.is_scaleout_share = capabilities.has_flag( + ShareCapabilities.SMB2_SHARE_CAP_SCALEOUT) + + # secure negotiate is only valid for SMB 3 dialects before 3.1.1 + if dialect < Dialects.SMB_3_1_1 and require_secure_negotiate: + self._verify_dialect_negotiate() + + def disconnect(self): + """ + Disconnects the tree connection. + """ + if not self._connected: + return + + log.info("Session: %s, Tree: %s - Disconnecting from Tree Connect" + % (self.session.username, self.share_name)) + + req = SMB2TreeDisconnect() + log.info("Session: %s, Tree: %s - Sending Tree Disconnect message" + % (self.session.username, self.share_name)) + log.debug(str(req)) + request = self.session.connection.send(req, + sid=self.session.session_id, + tid=self.tree_connect_id) + + log.info("Session: %s, Tree: %s - Receiving Tree Disconnect response" + % (self.session.username, self.share_name)) + res = self.session.connection.receive(request) + res_disconnect = SMB2TreeDisconnect() + res_disconnect.unpack(res['data'].get_value()) + log.debug(str(res_disconnect)) + self._connected = False + del self.session.tree_connect_table[self.tree_connect_id] + + def _verify_dialect_negotiate(self): + log_header = "Session: %s, Tree: %s" \ + % (self.session.username, self.share_name) + log.info("%s - Running secure negotiate process" % log_header) + ioctl_request = SMB2IOCTLRequest() + ioctl_request['ctl_code'] = \ + CtlCode.FSCTL_VALIDATE_NEGOTIATE_INFO + ioctl_request['file_id'] = b"\xff" * 16 + + val_neg = SMB2ValidateNegotiateInfoRequest() + val_neg['capabilities'] = \ + self.session.connection.client_capabilities + val_neg['guid'] = self.session.connection.client_guid + val_neg['security_mode'] = \ + self.session.connection.client_security_mode + val_neg['dialects'] = \ + self.session.connection.negotiated_dialects + + ioctl_request['buffer'] = val_neg + ioctl_request['max_output_response'] = len(val_neg) + ioctl_request['flags'] = IOCTLFlags.SMB2_0_IOCTL_IS_FSCTL + log.info("%s - Sending Secure Negotiate Validation message" + % log_header) + log.debug(str(ioctl_request)) + request = self.session.connection.send(ioctl_request, + sid=self.session.session_id, + tid=self.tree_connect_id) + + log.info("%s - Receiving secure negotiation response" % log_header) + response = self.session.connection.receive(request) + ioctl_resp = SMB2IOCTLResponse() + ioctl_resp.unpack(response['data'].get_value()) + log.debug(str(ioctl_resp)) + + log.info("%s - Unpacking secure negotiate response info" % log_header) + val_resp = SMB2ValidateNegotiateInfoResponse() + val_resp.unpack(ioctl_resp['buffer'].get_value()) + log.debug(str(val_resp)) + + self._verify("server capabilities", + val_resp['capabilities'].get_value(), + self.session.connection.server_capabilities.get_value()) + self._verify("server guid", + val_resp['guid'].get_value(), + self.session.connection.server_guid) + self._verify("server security mode", + val_resp['security_mode'].get_value(), + self.session.connection.server_security_mode) + self._verify("server dialect", + val_resp['dialect'].get_value(), + self.session.connection.dialect) + log.info("Session: %d, Tree: %d - Secure negotiate complete" + % (self.session.session_id, self.tree_connect_id)) + + def _verify(self, check, actual, expected): + log_header = "Session: %d, Tree: %d"\ + % (self.session.session_id, self.tree_connect_id) + if actual != expected: + raise SMBException("%s - Secure negotiate failed to verify %s, " + "Actual: %s, Expected: %s" + % (log_header, check, actual, expected)) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_connection.py b/tests/test_connection.py new file mode 100644 index 00000000..d16845b4 --- /dev/null +++ b/tests/test_connection.py @@ -0,0 +1,1022 @@ +import hashlib +import uuid + +import pytest +from cryptography.hazmat.primitives.ciphers import aead +from datetime import datetime +from smbprotocol.connection import Ciphers, Commands, Connection, Dialects, \ + HashAlgorithms, NegotiateContextType, SecurityMode, \ + SMB2EncryptionCapabilities, Smb2Flags, SMB2HeaderRequest, \ + SMB2HeaderResponse, SMB2NegotiateContextRequest, SMB2NegotiateRequest, \ + SMB2NegotiateResponse, SMB2PreauthIntegrityCapabilities, \ + SMB2TransformHeader, SMB3NegotiateRequest +from smbprotocol.ioctl import SMB2IOCTLRequest +from smbprotocol.exceptions import SMBException +from smbprotocol.session import Session + +from .utils import smb_real + + +def test_valid_hash_algorithm(): + expected = hashlib.sha512 + actual = HashAlgorithms.get_algorithm(0x1) + assert actual == expected + + +def test_invalid_hash_algorithm(): + with pytest.raises(KeyError) as exc: + HashAlgorithms.get_algorithm(0x2) + assert False # shouldn't be reached + + +def test_valid_cipher(): + expected = aead.AESCCM + actual = Ciphers.get_cipher(0x1) + assert actual == expected + + +def test_invalid_cipher(): + with pytest.raises(KeyError) as exc: + Ciphers.get_cipher(0x3) + assert False # shouldn't be reached + + +class TestSMB2HeaderRequest(object): + + def test_create_message(self): + header = SMB2HeaderRequest() + header['command'] = Commands.SMB2_SESSION_SETUP + header['message_id'] = 1 + header['process_id'] = 15 + header['session_id'] = 10 + expected = b"\xfe\x53\x4d\x42" \ + b"\x40\x00" \ + b"\x00\x00" \ + b"\x00\x00" \ + b"\x00\x00" \ + b"\x01\x00" \ + b"\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x01\x00\x00\x00\x00\x00\x00\x00" \ + b"\x0f\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x0a\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" + actual = header.pack() + assert len(header) == 64 + assert actual == expected + + def test_parse_message(self): + actual = SMB2HeaderRequest() + data = b"\xfe\x53\x4d\x42" \ + b"\x40\x00" \ + b"\x00\x00" \ + b"\x00\x00" \ + b"\x00\x00" \ + b"\x01\x00" \ + b"\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x01\x00\x00\x00\x00\x00\x00\x00" \ + b"\x0f\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x0a\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\x01\x02\x03\x04" + actual.unpack(data) + assert len(actual) == 68 + assert actual['protocol_id'].get_value() == b"\xfeSMB" + assert actual['structure_size'].get_value() == 64 + assert actual['credit_charge'].get_value() == 0 + assert actual['channel_sequence'].get_value() == 0 + assert actual['reserved'].get_value() == 0 + assert actual['command'].get_value() == Commands.SMB2_SESSION_SETUP + assert actual['credit_request'].get_value() == 0 + assert actual['flags'].get_value() == 0 + assert actual['next_command'].get_value() == 0 + assert actual['message_id'].get_value() == 1 + assert actual['process_id'].get_value() == 15 + assert actual['tree_id'].get_value() == 0 + assert actual['session_id'].get_value() == 10 + assert actual['signature'].get_value() == b"\x00" * 16 + assert actual['data'].get_value() == b"\x01\x02\x03\x04" + + +class TestSMB2HeaderResponse(object): + + def test_create_message(self): + header = SMB2HeaderResponse() + header['command'] = Commands.SMB2_SESSION_SETUP + header['message_id'] = 1 + header['session_id'] = 10 + expected = b"\xfe\x53\x4d\x42" \ + b"\x40\x00" \ + b"\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x01\x00" \ + b"\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x01\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x0a\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" + actual = header.pack() + assert len(header) == 64 + assert actual == expected + + def test_parse_message(self): + actual = SMB2HeaderResponse() + data = b"\xfe\x53\x4d\x42" \ + b"\x40\x00" \ + b"\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x01\x00" \ + b"\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x01\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x0a\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\x01\x02\x03\x04" + actual.unpack(data) + assert len(actual) == 68 + assert actual['protocol_id'].get_value() == b"\xfeSMB" + assert actual['structure_size'].get_value() == 64 + assert actual['credit_charge'].get_value() == 0 + assert actual['status'].get_value() == 0 + assert actual['command'].get_value() == Commands.SMB2_SESSION_SETUP + assert actual['credit_response'].get_value() == 0 + assert actual['flags'].get_value() == 0 + assert actual['next_command'].get_value() == 0 + assert actual['message_id'].get_value() == 1 + assert actual['reserved'].get_value() == 0 + assert actual['tree_id'].get_value() == 0 + assert actual['session_id'].get_value() == 10 + assert actual['signature'].get_value() == b"\x00" * 16 + assert actual['data'].get_value() == b"\x01\x02\x03\x04" + + +class TestSMB2NegotiateRequest(object): + + def test_create_message(self): + message = SMB2NegotiateRequest() + message['security_mode'] = SecurityMode.SMB2_NEGOTIATE_SIGNING_ENABLED + message['capabilities'] = 10 + message['client_guid'] = uuid.UUID(bytes=b"\x33" * 16) + message['dialects'] = [ + Dialects.SMB_2_0_2, + Dialects.SMB_2_1_0, + Dialects.SMB_3_0_0, + Dialects.SMB_3_0_2 + ] + expected = b"\x24\x00" \ + b"\x04\x00" \ + b"\x01\x00" \ + b"\x00\x00" \ + b"\x0a\x00\x00\x00" \ + b"\x33\x33\x33\x33\x33\x33\x33\x33" \ + b"\x33\x33\x33\x33\x33\x33\x33\x33" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\x02\x02" \ + b"\x10\x02" \ + b"\x00\x03" \ + b"\x02\x03" + actual = message.pack() + assert len(message) == 44 + assert actual == expected + + def test_parse_message(self): + actual = SMB2NegotiateRequest() + data = b"\x24\x00" \ + b"\x04\x00" \ + b"\x01\x00" \ + b"\x00\x00" \ + b"\x0a\x00\x00\x00" \ + b"\x33\x33\x33\x33\x33\x33\x33\x33" \ + b"\x33\x33\x33\x33\x33\x33\x33\x33" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\x02\x02" \ + b"\x10\x02" \ + b"\x00\x03" \ + b"\x02\x03" + actual.unpack(data) + assert len(actual) == 44 + assert actual['structure_size'].get_value() == 36 + assert actual['dialect_count'].get_value() == 4 + assert actual['security_mode'].get_value() == \ + SecurityMode.SMB2_NEGOTIATE_SIGNING_ENABLED + assert actual['reserved'].get_value() == 0 + assert actual['capabilities'].get_value() == 10 + assert actual['client_guid'].get_value() == \ + uuid.UUID(bytes=b"\x33" * 16) + assert actual['client_start_time'].get_value() == 0 + assert actual['dialects'].get_value() == [ + Dialects.SMB_2_0_2, + Dialects.SMB_2_1_0, + Dialects.SMB_3_0_0, + Dialects.SMB_3_0_2 + ] + + +class TestSMB3NegotiateRequest(object): + + def test_create_message(self): + message = SMB3NegotiateRequest() + message['security_mode'] = SecurityMode.SMB2_NEGOTIATE_SIGNING_ENABLED + message['capabilities'] = 10 + message['client_guid'] = uuid.UUID(bytes=b"\x33" * 16) + message['dialects'] = [ + Dialects.SMB_2_0_2, + Dialects.SMB_2_1_0, + Dialects.SMB_3_0_0, + Dialects.SMB_3_0_2, + Dialects.SMB_3_1_1 + ] + con_req = SMB2NegotiateContextRequest() + con_req['context_type'] = \ + NegotiateContextType.SMB2_ENCRYPTION_CAPABILITIES + + enc_cap = SMB2EncryptionCapabilities() + enc_cap['ciphers'] = [Ciphers.AES_128_GCM] + con_req['data'] = enc_cap + message['negotiate_context_list'] = [ + con_req + ] + expected = b"\x24\x00" \ + b"\x05\x00" \ + b"\x01\x00" \ + b"\x00\x00" \ + b"\x0a\x00\x00\x00" \ + b"\x33\x33\x33\x33\x33\x33\x33\x33" \ + b"\x33\x33\x33\x33\x33\x33\x33\x33" \ + b"\x70\x00\x00\x00" \ + b"\x01\x00" \ + b"\x00\x00" \ + b"\x02\x02" \ + b"\x10\x02" \ + b"\x00\x03" \ + b"\x02\x03" \ + b"\x11\x03" \ + b"\x00\x00" \ + b"\x02\x00\x04\x00\x00\x00\x00\x00" \ + b"\x01\x00\x02\x00" \ + b"\x00\x00\x00\x00" + actual = message.pack() + assert len(message) == 64 + assert actual == expected + + def test_create_message_one_dialect(self): + message = SMB3NegotiateRequest() + message['security_mode'] = SecurityMode.SMB2_NEGOTIATE_SIGNING_ENABLED + message['capabilities'] = 10 + message['client_guid'] = uuid.UUID(bytes=b"\x33" * 16) + message['dialects'] = [ + Dialects.SMB_3_1_1 + ] + con_req = SMB2NegotiateContextRequest() + con_req['context_type'] = \ + NegotiateContextType.SMB2_ENCRYPTION_CAPABILITIES + + enc_cap = SMB2EncryptionCapabilities() + enc_cap['ciphers'] = [Ciphers.AES_128_GCM] + con_req['data'] = enc_cap + message['negotiate_context_list'] = [ + con_req + ] + expected = b"\x24\x00" \ + b"\x01\x00" \ + b"\x01\x00" \ + b"\x00\x00" \ + b"\x0a\x00\x00\x00" \ + b"\x33\x33\x33\x33\x33\x33\x33\x33" \ + b"\x33\x33\x33\x33\x33\x33\x33\x33" \ + b"\x68\x00\x00\x00" \ + b"\x01\x00" \ + b"\x00\x00" \ + b"\x11\x03" \ + b"\x00\x00" \ + b"\x02\x00\x04\x00\x00\x00\x00\x00" \ + b"\x01\x00\x02\x00" \ + b"\x00\x00\x00\x00" + actual = message.pack() + assert len(message) == 56 + assert actual == expected + + def test_parse_message(self): + actual = SMB3NegotiateRequest() + data = b"\x24\x00" \ + b"\x05\x00" \ + b"\x01\x00" \ + b"\x00\x00" \ + b"\x0a\x00\x00\x00" \ + b"\x33\x33\x33\x33\x33\x33\x33\x33" \ + b"\x33\x33\x33\x33\x33\x33\x33\x33" \ + b"\x70\x00\x00\x00" \ + b"\x01\x00" \ + b"\x00\x00" \ + b"\x02\x02" \ + b"\x10\x02" \ + b"\x00\x03" \ + b"\x02\x03" \ + b"\x11\x03" \ + b"\x00\x00" \ + b"\x02\x00\x04\x00\x00\x00\x00\x00" \ + b"\x01\x00\x02\x00" \ + b"\x00\x00\x00\x00" + actual.unpack(data) + assert len(actual) == 60 + assert actual['structure_size'].get_value() == 36 + assert actual['dialect_count'].get_value() == 5 + assert actual['security_mode'].get_value() == \ + SecurityMode.SMB2_NEGOTIATE_SIGNING_ENABLED + assert actual['reserved'].get_value() == 0 + assert actual['capabilities'].get_value() == 10 + assert actual['client_guid'].get_value() == \ + uuid.UUID(bytes=b"\x33" * 16) + assert actual['negotiate_context_offset'].get_value() == 112 + assert actual['negotiate_context_count'].get_value() == 1 + assert actual['reserved2'].get_value() == 0 + assert actual['dialects'].get_value() == [ + Dialects.SMB_2_0_2, + Dialects.SMB_2_1_0, + Dialects.SMB_3_0_0, + Dialects.SMB_3_0_2, + Dialects.SMB_3_1_1 + ] + assert actual['padding'].get_value() == b"\x00\x00" + + assert len(actual['negotiate_context_list'].get_value()) == 1 + neg_con = actual['negotiate_context_list'][0] + assert isinstance(neg_con, SMB2NegotiateContextRequest) + assert len(neg_con) == 12 + assert neg_con['context_type'].get_value() == \ + NegotiateContextType.SMB2_ENCRYPTION_CAPABILITIES + assert neg_con['data_length'].get_value() == 4 + assert neg_con['reserved'].get_value() == 0 + assert isinstance(neg_con['data'].get_value(), + SMB2EncryptionCapabilities) + assert neg_con['data']['cipher_count'].get_value() == 1 + assert neg_con['data']['ciphers'].get_value() == [Ciphers.AES_128_GCM] + + +class TestSMB2NegotiateContextRequest(object): + + def test_create_message(self): + message = SMB2NegotiateContextRequest() + message['context_type'] = \ + NegotiateContextType.SMB2_ENCRYPTION_CAPABILITIES + + enc_cap = SMB2EncryptionCapabilities() + enc_cap['ciphers'] = [Ciphers.AES_128_GCM] + message['data'] = enc_cap + expected = b"\x02\x00" \ + b"\x04\x00" \ + b"\x00\x00\x00\x00" \ + b"\x01\x00" \ + b"\x02\x00" \ + b"\x00\x00\x00\x00" + actual = message.pack() + assert len(message) == 16 + assert actual == expected + + def test_parse_message(self): + actual = SMB2NegotiateContextRequest() + data = b"\x02\x00" \ + b"\x04\x00" \ + b"\x00\x00\x00\x00" \ + b"\x01\x00" \ + b"\x02\x00" + actual.unpack(data) + assert len(actual) == 12 + assert actual['context_type'].get_value() == \ + NegotiateContextType.SMB2_ENCRYPTION_CAPABILITIES + assert actual['data_length'].get_value() == 4 + assert actual['reserved'].get_value() == 0 + assert isinstance(actual['data'].get_value(), + SMB2EncryptionCapabilities) + assert actual['data']['cipher_count'].get_value() == 1 + assert actual['data']['ciphers'].get_value() == [Ciphers.AES_128_GCM] + + def test_parse_message_invalid_context_type(self): + actual = SMB2NegotiateContextRequest() + data = b"\x03\x00" \ + b"\x04\x00" \ + b"\x00\x00\x00\x00" \ + b"\x01\x00" \ + b"\x02\x00" + with pytest.raises(Exception) as exc: + actual.unpack(data) + assert str(exc.value) == "Enum value 3 does not exist in enum type " \ + "" + + +class TestSMB2PreauthIntegrityCapabilities(object): + + def test_create_message(self): + message = SMB2PreauthIntegrityCapabilities() + message['hash_algorithms'] = [ + HashAlgorithms.SHA_512 + ] + message['salt'] = b"\x01" * 16 + expected = b"\x01\x00" \ + b"\x10\x00" \ + b"\x01\x00" \ + b"\x01\x01\x01\x01\x01\x01\x01\x01" \ + b"\x01\x01\x01\x01\x01\x01\x01\x01" + actual = message.pack() + assert len(message) == 22 + assert actual == expected + + def test_parse_message(self): + actual = SMB2PreauthIntegrityCapabilities() + data = b"\x01\x00" \ + b"\x10\x00" \ + b"\x01\x00" \ + b"\x01\x01\x01\x01\x01\x01\x01\x01" \ + b"\x01\x01\x01\x01\x01\x01\x01\x01" + actual.unpack(data) + assert len(actual) == 22 + assert actual['hash_algorithm_count'].get_value() == 1 + assert actual['salt_length'].get_value() == 16 + assert actual['hash_algorithms'].get_value() == [ + HashAlgorithms.SHA_512 + ] + assert actual['salt'].get_value() == b"\x01" * 16 + + +class TestSMB2EncryptionCapabilities(object): + + def test_create_message(self): + message = SMB2EncryptionCapabilities() + message['ciphers'] = [ + Ciphers.AES_128_CCM, + Ciphers.AES_128_GCM + ] + expected = b"\x02\x00" \ + b"\x01\x00" \ + b"\x02\x00" + actual = message.pack() + assert len(message) == 6 + assert actual == expected + + def test_parse_message(self): + actual = SMB2EncryptionCapabilities() + data = b"\x02\x00" \ + b"\x01\x00" \ + b"\x02\x00" + actual.unpack(data) + assert len(actual) == 6 + assert actual['cipher_count'].get_value() == 2 + assert actual['ciphers'].get_value() == [ + Ciphers.AES_128_CCM, + Ciphers.AES_128_GCM + ] + + +class TestSMB2NegotiateResponse(object): + + def test_create_message(self): + message = SMB2NegotiateResponse() + message['security_mode'] = SecurityMode.SMB2_NEGOTIATE_SIGNING_ENABLED + message['dialect_revision'] = Dialects.SMB_3_0_2 + message['server_guid'] = uuid.UUID(bytes=b"\x11" * 16) + message['capabilities'] = 39 + message['max_transact_size'] = 8388608 + message['max_read_size'] = 8388608 + message['max_write_size'] = 8388608 + message['system_time'] = datetime( + year=2017, month=11, day=15, hour=11, minute=32, second=12, + microsecond=1616) + message['server_start_time'] = datetime( + year=2017, month=11, day=15, hour=11, minute=27, second=26, + microsecond=349606) + message['buffer'] = b"\x01\x02\x03\x04\x05\x06\x07\x08" \ + b"\x09\x10" + + expected = b"\x41\x00" \ + b"\x01\x00" \ + b"\x02\x03" \ + b"\x00\x00" \ + b"\x11\x11\x11\x11\x11\x11\x11\x11" \ + b"\x11\x11\x11\x11\x11\x11\x11\x11" \ + b"\x27\x00\x00\x00" \ + b"\x00\x00\x80\x00" \ + b"\x00\x00\x80\x00" \ + b"\x00\x00\x80\x00" \ + b"\x20\xc5\x0d\x61\x05\x5e\xd3\x01" \ + b"\x7c\xbb\xca\xb6\x04\x5e\xd3\x01" \ + b"\x80\x00" \ + b"\x0a\x00" \ + b"\x00\x00\x00\x00" \ + b"\x01\x02\x03\x04\x05\x06\x07\x08" \ + b"\x09\x10" + actual = message.pack() + assert len(message) == 74 + assert actual == expected + + def test_create_message_3_1_1(self): + message = SMB2NegotiateResponse() + message['security_mode'] = SecurityMode.SMB2_NEGOTIATE_SIGNING_ENABLED + message['dialect_revision'] = Dialects.SMB_3_1_1 + message['server_guid'] = uuid.UUID(bytes=b"\x11" * 16) + message['capabilities'] = 39 + message['max_transact_size'] = 8388608 + message['max_read_size'] = 8388608 + message['max_write_size'] = 8388608 + message['system_time'] = datetime( + year=2017, month=11, day=15, hour=11, minute=32, second=12, + microsecond=1616) + message['server_start_time'] = datetime( + year=2017, month=11, day=15, hour=11, minute=27, second=26, + microsecond=349606) + message['buffer'] = b"\x01\x02\x03\x04\x05\x06\x07\x08" \ + b"\x09\x10" + + int_cap = SMB2PreauthIntegrityCapabilities() + int_cap['hash_algorithms'] = [HashAlgorithms.SHA_512] + int_cap['salt'] = b"\x22" * 32 + + negotiate_context = SMB2NegotiateContextRequest() + negotiate_context['context_type'] = \ + NegotiateContextType.SMB2_PREAUTH_INTEGRITY_CAPABILITIES + negotiate_context['data'] = int_cap + + message['negotiate_context_list'] = [negotiate_context] + expected = b"\x41\x00" \ + b"\x01\x00" \ + b"\x11\x03" \ + b"\x01\x00" \ + b"\x11\x11\x11\x11\x11\x11\x11\x11" \ + b"\x11\x11\x11\x11\x11\x11\x11\x11" \ + b"\x27\x00\x00\x00" \ + b"\x00\x00\x80\x00" \ + b"\x00\x00\x80\x00" \ + b"\x00\x00\x80\x00" \ + b"\x20\xc5\x0d\x61\x05\x5e\xd3\x01" \ + b"\x7c\xbb\xca\xb6\x04\x5e\xd3\x01" \ + b"\x80\x00" \ + b"\x0a\x00" \ + b"\x90\x00\x00\x00" \ + b"\x01\x02\x03\x04\x05\x06\x07\x08" \ + b"\x09\x10" \ + b"\x00\x00\x00\x00\x00\x00" \ + b"\x01\x00\x26\x00\x00\x00\x00\x00" \ + b"\x01\x00\x20\x00\x01\x00\x22\x22" \ + b"\x22\x22\x22\x22\x22\x22\x22\x22" \ + b"\x22\x22\x22\x22\x22\x22\x22\x22" \ + b"\x22\x22\x22\x22\x22\x22\x22\x22" \ + b"\x22\x22\x22\x22\x22\x22" \ + b"\x00\x00" + actual = message.pack() + assert len(message) == 128 + assert actual == expected + + def test_parse_message(self): + actual = SMB2NegotiateResponse() + data = b"\x41\x00" \ + b"\x01\x00" \ + b"\x02\x03" \ + b"\x00\x00" \ + b"\x11\x11\x11\x11\x11\x11\x11\x11" \ + b"\x11\x11\x11\x11\x11\x11\x11\x11" \ + b"\x67\x00\x00\x00" \ + b"\x00\x00\x80\x00" \ + b"\x00\x00\x80\x00" \ + b"\x00\x00\x80\x00" \ + b"\x14\x85\x12\x8b\xc2\x5e\xd3\x01" \ + b"\x04\x88\x4d\x21\xc2\x5e\xd3\x01" \ + b"\x80\x00" \ + b"\x78\x00" \ + b"\x00\x00\x00\x00" \ + b"\x60\x76\x06\x06\x2b\x06\x01\x05" \ + b"\x05\x02\xa0\x6c\x30\x6a\xa0\x3c" \ + b"\x30\x3a\x06\x0a\x2b\x06\x01\x04" \ + b"\x01\x82\x37\x02\x02\x1e\x06\x09" \ + b"\x2a\x86\x48\x82\xf7\x12\x01\x02" \ + b"\x02\x06\x09\x2a\x86\x48\x86\xf7" \ + b"\x12\x01\x02\x02\x06\x0a\x2a\x86" \ + b"\x48\x86\xf7\x12\x01\x02\x02\x03" \ + b"\x06\x0a\x2b\x06\x01\x04\x01\x82" \ + b"\x37\x02\x02\x0a\xa3\x2a\x30\x28" \ + b"\xa0\x26\x1b\x24\x6e\x6f\x74\x5f" \ + b"\x64\x65\x66\x69\x6e\x65\x64\x5f" \ + b"\x69\x6e\x5f\x52\x46\x43\x34\x31" \ + b"\x37\x38\x40\x70\x6c\x65\x61\x73" \ + b"\x65\x5f\x69\x67\x6e\x6f\x72\x65" + actual.unpack(data) + + assert len(actual) == 184 + assert actual['structure_size'].get_value() == 65 + + assert actual['security_mode'].get_value() == \ + SecurityMode.SMB2_NEGOTIATE_SIGNING_ENABLED + assert actual['dialect_revision'].get_value() == Dialects.SMB_3_0_2 + assert actual['negotiate_context_count'].get_value() == 0 + assert actual['server_guid'].get_value() == uuid.UUID( + bytes=b"\x11" * 16) + assert actual['capabilities'].get_value() == 103 + assert actual['max_transact_size'].get_value() == 8388608 + assert actual['max_read_size'].get_value() == 8388608 + assert actual['max_write_size'].get_value() == 8388608 + assert actual['system_time'].get_value() == datetime( + year=2017, month=11, day=16, hour=10, minute=6, second=17, + microsecond=378946) + assert actual['server_start_time'].get_value() == datetime( + year=2017, month=11, day=16, hour=10, minute=3, second=19, + microsecond=927194) + assert actual['security_buffer_offset'].get_value() == 128 + assert actual['security_buffer_length'].get_value() == 120 + assert actual['negotiate_context_offset'].get_value() == 0 + assert isinstance(actual['buffer'].get_value(), bytes) + assert len(actual['buffer']) == 120 + assert actual['padding'].get_value() == b"" + assert actual['negotiate_context_list'].get_value() == [] + + def test_parse_message_3_1_1(self): + actual = SMB2NegotiateResponse() + data = b"\x41\x00" \ + b"\x01\x00" \ + b"\x11\x03" \ + b"\x01\x00" \ + b"\x11\x11\x11\x11\x11\x11\x11\x11" \ + b"\x11\x11\x11\x11\x11\x11\x11\x11" \ + b"\x27\x00\x00\x00" \ + b"\x00\x00\x80\x00" \ + b"\x00\x00\x80\x00" \ + b"\x00\x00\x80\x00" \ + b"\x24\xc5\x0d\x61\x05\x5e\xd3\x01" \ + b"\x7f\xbb\xca\xb6\x04\x5e\xd3\x01" \ + b"\x80\x00" \ + b"\x78\x00" \ + b"\xf8\x00\x00\x00" \ + b"\x60\x76\x06\x06\x2b\x06\x01\x05" \ + b"\x05\x02\xa0\x6c\x30\x6a\xa0\x3c" \ + b"\x30\x3a\x06\x0a\x2b\x06\x01\x04" \ + b"\x01\x82\x37\x02\x02\x1e\x06\x09" \ + b"\x2a\x86\x48\x82\xf7\x12\x01\x02" \ + b"\x02\x06\x09\x2a\x86\x48\x86\xf7" \ + b"\x12\x01\x02\x02\x06\x0a\x2a\x86" \ + b"\x48\x86\xf7\x12\x01\x02\x02\x03" \ + b"\x06\x0a\x2b\x06\x01\x04\x01\x82" \ + b"\x37\x02\x02\x0a\xa3\x2a\x30\x28" \ + b"\xa0\x26\x1b\x24\x6e\x6f\x74\x5f" \ + b"\x64\x65\x66\x69\x6e\x65\x64\x5f" \ + b"\x69\x6e\x5f\x52\x46\x43\x34\x31" \ + b"\x37\x38\x40\x70\x6c\x65\x61\x73" \ + b"\x65\x5f\x69\x67\x6e\x6f\x72\x65" \ + b"" \ + b"\x01\x00\x26\x00\x00\x00\x00\x00" \ + b"\x01\x00\x20\x00\x01\x00\x22\x22" \ + b"\x22\x22\x22\x22\x22\x22\x22\x22" \ + b"\x22\x22\x22\x22\x22\x22\x22\x22" \ + b"\x22\x22\x22\x22\x22\x22\x22\x22" \ + b"\x22\x22\x22\x22\x22\x22" + actual.unpack(data) + + assert len(actual) == 230 + assert actual['structure_size'].get_value() == 65 + + assert actual['security_mode'].get_value() == \ + SecurityMode.SMB2_NEGOTIATE_SIGNING_ENABLED + assert actual['dialect_revision'].get_value() == Dialects.SMB_3_1_1 + assert actual['negotiate_context_count'].get_value() == 1 + assert actual['server_guid'].get_value() == uuid.UUID( + bytes=b"\x11" * 16) + assert actual['capabilities'].get_value() == 39 + assert actual['max_transact_size'].get_value() == 8388608 + assert actual['max_read_size'].get_value() == 8388608 + assert actual['max_write_size'].get_value() == 8388608 + assert actual['system_time'].get_value() == datetime( + year=2017, month=11, day=15, hour=11, minute=32, second=12, + microsecond=1616) + assert actual['server_start_time'].get_value() == datetime( + year=2017, month=11, day=15, hour=11, minute=27, second=26, + microsecond=349606) + assert actual['security_buffer_offset'].get_value() == 128 + assert actual['security_buffer_length'].get_value() == 120 + assert actual['negotiate_context_offset'].get_value() == 248 + assert isinstance(actual['buffer'].get_value(), bytes) + assert len(actual['buffer']) == 120 + assert actual['padding'].get_value() == b"" + + assert isinstance(actual['negotiate_context_list'].get_value(), list) + assert len(actual['negotiate_context_list'].get_value()) == 1 + + neg_context = actual['negotiate_context_list'].get_value()[0] + assert isinstance(neg_context, SMB2NegotiateContextRequest) + assert neg_context['context_type'].get_value() == \ + NegotiateContextType.SMB2_PREAUTH_INTEGRITY_CAPABILITIES + assert neg_context['data_length'].get_value() == 38 + assert neg_context['reserved'].get_value() == 0 + + preauth_cap = neg_context['data'] + assert preauth_cap['hash_algorithm_count'].get_value() == 1 + assert preauth_cap['salt_length'].get_value() == 32 + assert preauth_cap['hash_algorithms'].get_value() == [ + HashAlgorithms.SHA_512 + ] + assert preauth_cap['salt'].get_value() == b"\x22" * 32 + + +class TestSMB2TransformHeader(object): + + def test_create_message(self): + message = SMB2TransformHeader() + message['nonce'] = b"\xff" * 16 + message['original_message_size'] = 4 + message['session_id'] = 1 + message['data'] = b"\x01\x02\x03\x04" + expected = b"\xfd\x53\x4d\x42" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\x04\x00\x00\x00" \ + b"\x00\x00" \ + b"\x01\x00" \ + b"\x01\x00\x00\x00\x00\x00\x00\x00" \ + b"\x01\x02\x03\x04" + actual = message.pack() + assert len(message) == 56 + assert actual == expected + + def test_parse_message(self): + actual = SMB2TransformHeader() + data = b"\xfd\x53\x4d\x42" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\x04\x00\x00\x00" \ + b"\x00\x00" \ + b"\x01\x00" \ + b"\x01\x00\x00\x00\x00\x00\x00\x00" \ + b"\x01\x02\x03\x04" + actual.unpack(data) + assert len(actual) == 56 + assert actual['protocol_id'].get_value() == b"\xfd\x53\x4d\x42" + assert actual['signature'].get_value() == b"\x00" * 16 + assert actual['nonce'].get_value() == b"\xff" * 16 + assert actual['original_message_size'].get_value() == 4 + assert actual['reserved'].get_value() == 0 + assert actual['flags'].get_value() == 1 + assert actual['session_id'].get_value() == 1 + assert actual['data'].get_value() == b"\x01\x02\x03\x04" + + +class TestConnection(object): + + def test_dialect_2_0_2(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3]) + connection.connect(Dialects.SMB_2_0_2) + try: + assert connection.dialect == Dialects.SMB_2_0_2 + assert connection.negotiated_dialects == [Dialects.SMB_2_0_2] + assert connection.gss_negotiate_token is not None + assert len(connection.preauth_integrity_hash_value) == 2 + assert len(connection.salt) == 32 + assert connection.sequence_window['low'] == 1 + assert connection.sequence_window['high'] == 2 + assert connection.client_security_mode == \ + SecurityMode.SMB2_NEGOTIATE_SIGNING_REQUIRED + + # server settings override the require signing + assert connection.server_security_mode is None + assert not connection.supports_encryption + assert connection.require_signing + finally: + connection.disconnect() + + def test_dialect_2_1_0(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3]) + connection.connect(Dialects.SMB_2_1_0) + try: + assert connection.dialect == Dialects.SMB_2_1_0 + assert connection.negotiated_dialects == [Dialects.SMB_2_1_0] + assert connection.gss_negotiate_token is not None + assert len(connection.preauth_integrity_hash_value) == 2 + assert len(connection.salt) == 32 + assert connection.sequence_window['low'] == 1 + assert connection.sequence_window['high'] == 2 + assert connection.client_security_mode == \ + SecurityMode.SMB2_NEGOTIATE_SIGNING_REQUIRED + + # server settings override the require signing + assert connection.server_security_mode is None + assert not connection.supports_encryption + assert connection.require_signing + finally: + connection.disconnect() + + def test_dialect_3_0_0(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3]) + connection.connect(Dialects.SMB_3_0_0) + try: + assert connection.dialect == Dialects.SMB_3_0_0 + assert connection.negotiated_dialects == [Dialects.SMB_3_0_0] + assert connection.gss_negotiate_token is not None + assert len(connection.preauth_integrity_hash_value) == 2 + assert len(connection.salt) == 32 + assert connection.sequence_window['low'] == 1 + assert connection.sequence_window['high'] == 2 + assert connection.client_security_mode == \ + SecurityMode.SMB2_NEGOTIATE_SIGNING_REQUIRED + + # server settings override the require signing + assert connection.server_security_mode == \ + SecurityMode.SMB2_NEGOTIATE_SIGNING_ENABLED + assert connection.supports_encryption + assert connection.require_signing + finally: + connection.disconnect() + + def test_dialect_3_0_2(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3]) + connection.connect(Dialects.SMB_3_0_2) + try: + assert connection.dialect == Dialects.SMB_3_0_2 + assert connection.negotiated_dialects == [Dialects.SMB_3_0_2] + assert connection.gss_negotiate_token is not None + assert len(connection.preauth_integrity_hash_value) == 2 + assert len(connection.salt) == 32 + assert connection.sequence_window['low'] == 1 + assert connection.sequence_window['high'] == 2 + assert connection.client_security_mode == \ + SecurityMode.SMB2_NEGOTIATE_SIGNING_REQUIRED + + # server settings override the require signing + assert connection.server_security_mode == \ + SecurityMode.SMB2_NEGOTIATE_SIGNING_ENABLED + assert connection.supports_encryption + assert connection.require_signing + finally: + connection.disconnect() + + def test_dialect_3_1_1_not_require_signing(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3], False) + connection.connect(Dialects.SMB_3_1_1) + try: + assert connection.dialect == Dialects.SMB_3_1_1 + assert connection.negotiated_dialects == [Dialects.SMB_3_1_1] + assert connection.gss_negotiate_token is not None + assert len(connection.preauth_integrity_hash_value) == 2 + assert len(connection.salt) == 32 + assert connection.sequence_window['low'] == 1 + assert connection.sequence_window['high'] == 2 + assert connection.client_security_mode == \ + SecurityMode.SMB2_NEGOTIATE_SIGNING_ENABLED + + # server settings override the require signing + assert connection.server_security_mode == \ + SecurityMode.SMB2_NEGOTIATE_SIGNING_ENABLED + assert connection.supports_encryption + assert not connection.require_signing + finally: + connection.disconnect() + + def test_dialect_implicit_require_signing(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3], True) + connection.connect() + try: + assert connection.dialect == Dialects.SMB_3_1_1 + assert connection.negotiated_dialects == [ + Dialects.SMB_2_0_2, + Dialects.SMB_2_1_0, + Dialects.SMB_3_0_0, + Dialects.SMB_3_0_2, + Dialects.SMB_3_1_1 + ] + assert connection.gss_negotiate_token is not None + assert len(connection.preauth_integrity_hash_value) == 2 + assert len(connection.salt) == 32 + assert connection.sequence_window['low'] == 1 + assert connection.sequence_window['high'] == 2 + assert connection.client_security_mode == \ + SecurityMode.SMB2_NEGOTIATE_SIGNING_REQUIRED + + # server settings override the require signing + assert connection.server_security_mode == \ + SecurityMode.SMB2_NEGOTIATE_SIGNING_ENABLED + assert connection.supports_encryption + assert connection.require_signing + finally: + connection.disconnect() + + def test_verify_message_skip(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3], True) + connection.connect() + try: + header = SMB2HeaderRequest() + header['message_id'] = 0xFFFFFFFFFFFFFFFF + expected = header.pack() + connection._verify(header) + actual = header.pack() + assert actual == expected + finally: + connection.disconnect() + + def test_verify_fail_no_session(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3], True) + connection.connect() + try: + header = SMB2HeaderRequest() + header['message_id'] = 1 + header['flags'].set_flag(Smb2Flags.SMB2_FLAGS_SIGNED) + header['session_id'] = 100 + with pytest.raises(SMBException) as exc: + connection._verify(header) + assert str(exc.value) == "Failed to find session 100 for " \ + "message verification" + finally: + connection.disconnect() + + def test_verify_mistmatch(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3], True) + session = Session(connection, smb_real[0], smb_real[1]) + connection.connect() + try: + session.connect() + header = connection.preauth_integrity_hash_value[-1] + # just set some random values for verifiation failure + header['flags'].set_flag(Smb2Flags.SMB2_FLAGS_SIGNED) + header['signature'] = b"\xff" * 16 + with pytest.raises(SMBException) as exc: + connection._verify(header, verify_session=True) + assert "Server message signature could not be verified:" in \ + str(exc.value) + finally: + connection.disconnect(True) + + def test_decrypt_invalid_flag(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3], True) + session = Session(connection, smb_real[0], smb_real[1]) + connection.connect() + try: + session.connect() + # just get some random message + header = connection.preauth_integrity_hash_value[-1] + enc_header = connection._encrypt(header.pack(), session) + assert isinstance(enc_header, SMB2TransformHeader) + enc_header['flags'] = 5 + with pytest.raises(SMBException) as exc: + connection._decrypt(enc_header) + assert str(exc.value) == "Expecting flag of 0x0001 but got 5 in " \ + "the SMB Transform Header Response" + finally: + connection.disconnect(True) + + def test_decrypt_invalid_session_id(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3], True) + session = Session(connection, smb_real[0], smb_real[1]) + connection.connect() + try: + session.connect() + # just get some random message + header = connection.preauth_integrity_hash_value[-1] + enc_header = connection._encrypt(header.pack(), session) + assert isinstance(enc_header, SMB2TransformHeader) + enc_header['session_id'] = 100 + with pytest.raises(SMBException) as exc: + connection._decrypt(enc_header) + assert str(exc.value) == "Failed to find valid session 100 for " \ + "message decryption" + finally: + connection.disconnect(True) + + def test_requested_credits_greater_than_available(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3], True) + connection.connect() + try: + msg = SMB2IOCTLRequest() + msg['max_output_response'] = 65538 # results in 2 credits required + with pytest.raises(SMBException) as exc: + connection._generate_packet_header(msg, None, None, 0) + assert str(exc.value) == "Request requires 2 credits but only 1 " \ + "credits are available" + finally: + connection.disconnect() + + def test_send_invalid_tree_id(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3]) + session = Session(connection, smb_real[0], smb_real[1]) + connection.connect() + try: + session.connect() + msg = SMB2IOCTLRequest() + msg['file_id'] = b"\xff" * 16 + with pytest.raises(SMBException) as exc: + connection.send(msg, session.session_id, 10) + assert str(exc.value) == "Cannot find Tree with the ID 10 in " \ + "the session tree table" + finally: + connection.disconnect() diff --git a/tests/test_create_contexts.py b/tests/test_create_contexts.py new file mode 100644 index 00000000..02f4a2a3 --- /dev/null +++ b/tests/test_create_contexts.py @@ -0,0 +1,1064 @@ +import uuid + +from datetime import datetime +from smbprotocol.connection import NtStatus +from smbprotocol.create_contexts import CreateContextName, \ + DurableHandleFlags, LeaseRequestFlags, LeaseResponseFlags, LeaseState, \ + SMB2CreateAllocationSize, SMB2CreateAppInstanceId, \ + SMB2CreateAppInstanceVersion, SMB2CreateContextRequest, \ + SMB2CreateDurableHandleReconnect, SMB2CreateDurableHandleReconnectV2, \ + SMB2CreateDurableHandleRequest, SMB2CreateDurableHandleRequestV2, \ + SMB2CreateDurableHandleResponse, SMB2CreateDurableHandleResponseV2, \ + SMB2CreateEABuffer, SMB2CreateQueryMaximalAccessRequest, \ + SMB2CreateQueryMaximalAccessResponse, SMB2CreateQueryOnDiskIDResponse, \ + SMB2CreateRequestLease, SMB2CreateRequestLeaseV2, \ + SMB2CreateResponseLease, SMB2CreateResponseLeaseV2, \ + SMB2CreateTimewarpToken, SMB2SVHDXOpenDeviceContextRequest, \ + SMB2SVHDXOpenDeviceContextResponse, SMB2SVHDXOpenDeviceContextV2Request, \ + SMB2SVHDXOpenDeviceContextV2Response, SVHDXOriginatorFlags + + +class TestCreateContextName(object): + def test_get_response_known(self): + name = CreateContextName.SMB2_CREATE_QUERY_ON_DISK_ID + actual = CreateContextName.get_response_structure(name) + assert isinstance(actual, SMB2CreateQueryOnDiskIDResponse) + + def test_get_response_unknown(self): + name = CreateContextName.SMB2_CREATE_EA_BUFFER + expected = None + actual = CreateContextName.get_response_structure(name) + assert actual == expected + + +class TestSMB2CreateContextName(object): + + def test_create_message(self): + ea_buffer1 = SMB2CreateEABuffer() + ea_buffer1['ea_name'] = "Authors\x00".encode('ascii') + ea_buffer1['ea_value'] = "Jordan Borean".encode("utf-8") + + ea_buffer2 = SMB2CreateEABuffer() + ea_buffer2['ea_name'] = "Title\x00".encode('ascii') + ea_buffer2['ea_value'] = "Jordan Borean Title".encode('utf-8') + + ea_buffers = SMB2CreateContextRequest() + ea_buffers['buffer_name'] = CreateContextName.SMB2_CREATE_EA_BUFFER + ea_buffers['buffer_data'] = SMB2CreateEABuffer.pack_multiple([ + ea_buffer1, ea_buffer2 + ]) + + alloc_size = SMB2CreateAllocationSize() + alloc_size['allocation_size'] = 1024 + + alloc_size_context = SMB2CreateContextRequest() + alloc_size_context['buffer_name'] = \ + CreateContextName.SMB2_CREATE_ALLOCATION_SIZE + alloc_size_context['buffer_data'] = alloc_size + + query_disk = SMB2CreateContextRequest() + query_disk['buffer_name'] = \ + CreateContextName.SMB2_CREATE_QUERY_ON_DISK_ID + + expected = b"\x60\x00\x00\x00" \ + b"\x10\x00" \ + b"\x04\x00" \ + b"\x00\x00" \ + b"\x18\x00" \ + b"\x41\x00\x00\x00" \ + b"\x45\x78\x74\x41" \ + b"\x00\x00\x00\x00" \ + b"\x20\x00\x00\x00" \ + b"\x00" \ + b"\x07" \ + b"\x0d\x00" \ + b"\x41\x75\x74\x68\x6f\x72\x73\x00" \ + b"\x4a\x6f\x72\x64\x61\x6e\x20\x42" \ + b"\x6f\x72\x65\x61\x6e" \ + b"\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00" \ + b"\x05" \ + b"\x13\x00" \ + b"\x54\x69\x74\x6c\x65\x00" \ + b"\x4a\x6f\x72\x64\x61\x6e\x20\x42" \ + b"\x6f\x72\x65\x61\x6e\x20\x54\x69" \ + b"\x74\x6c\x65" \ + b"\x00\x00\x00\x00\x00\x00\x00" \ + b"\x20\x00\x00\x00" \ + b"\x10\x00" \ + b"\x04\x00" \ + b"\x00\x00" \ + b"\x18\x00" \ + b"\x08\x00\x00\x00" \ + b"\x41\x6c\x53\x69" \ + b"\x00\x00\x00\x00" \ + b"\x00\x04\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x10\x00" \ + b"\x04\x00" \ + b"\x00\x00" \ + b"\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x51\x46\x69\x64" + + # has not padding on the end + assert len(ea_buffers) == 89 + assert len(alloc_size_context) == 32 + assert len(query_disk) == 20 + + actual = SMB2CreateContextRequest.pack_multiple([ + ea_buffers, + alloc_size_context, + query_disk + ]) + + # now has padding on the end + assert len(ea_buffers) == 96 + assert len(alloc_size_context) == 32 + assert len(query_disk) == 20 + assert actual == expected + + def test_parse_message(self): + actual1 = SMB2CreateContextRequest() + actual2 = SMB2CreateContextRequest() + actual3 = SMB2CreateContextRequest() + data = b"\x60\x00\x00\x00" \ + b"\x10\x00" \ + b"\x04\x00" \ + b"\x00\x00" \ + b"\x18\x00" \ + b"\x41\x00\x00\x00" \ + b"\x45\x78\x74\x41" \ + b"\x00\x00\x00\x00" \ + b"\x20\x00\x00\x00" \ + b"\x00" \ + b"\x07" \ + b"\x0d\x00" \ + b"\x41\x75\x74\x68\x6f\x72\x73\x00" \ + b"\x4a\x6f\x72\x64\x61\x6e\x20\x42" \ + b"\x6f\x72\x65\x61\x6e" \ + b"\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00" \ + b"\x05" \ + b"\x13\x00" \ + b"\x54\x69\x74\x6c\x65\x00" \ + b"\x4a\x6f\x72\x64\x61\x6e\x20\x42" \ + b"\x6f\x72\x65\x61\x6e\x20\x54\x69" \ + b"\x74\x6c\x65" \ + b"\x00\x00\x00\x00\x00\x00\x00" \ + b"\x20\x00\x00\x00" \ + b"\x10\x00" \ + b"\x04\x00" \ + b"\x00\x00" \ + b"\x18\x00" \ + b"\x08\x00\x00\x00" \ + b"\x41\x6c\x53\x69" \ + b"\x00\x00\x00\x00" \ + b"\x00\x04\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x10\x00" \ + b"\x04\x00" \ + b"\x00\x00" \ + b"\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x51\x46\x69\x64" + data = actual1.unpack(data) + data = actual2.unpack(data) + data = actual3.unpack(data) + assert data == b"" + + assert len(actual1) == 96 + assert actual1['next'].get_value() == 96 + assert actual1['name_offset'].get_value() == 16 + assert actual1['name_length'].get_value() == 4 + assert actual1['reserved'].get_value() == 0 + assert actual1['data_offset'].get_value() == 24 + assert actual1['data_length'].get_value() == 65 + assert actual1['buffer_name'].get_value() == b"\x45\x78\x74\x41" + assert actual1['padding'].get_value() == b"\x00\x00\x00\x00" + + ea_buffer_data = actual1['buffer_data'].get_value() + actual_ea_buffer1 = SMB2CreateEABuffer() + actual_ea_buffer2 = SMB2CreateEABuffer() + ea_buffer_data = actual_ea_buffer1.unpack(ea_buffer_data) + ea_buffer_data = actual_ea_buffer2.unpack(ea_buffer_data) + assert ea_buffer_data == b"" + assert len(actual_ea_buffer1) == 32 + assert actual_ea_buffer1['next_entry_offset'].get_value() == 32 + assert actual_ea_buffer1['flags'].get_value() == 0 + assert actual_ea_buffer1['ea_name_length'].get_value() == 7 + assert actual_ea_buffer1['ea_value_length'].get_value() == 13 + assert actual_ea_buffer1['ea_name'].get_value() == \ + "Authors\x00".encode("ascii") + assert actual_ea_buffer1['ea_value'].get_value() == b"Jordan Borean" + assert actual_ea_buffer1['padding'].get_value() == b"\x00\x00\x00" + assert len(actual_ea_buffer2) == 33 + assert actual_ea_buffer2['next_entry_offset'].get_value() == 0 + assert actual_ea_buffer2['flags'].get_value() == 0 + assert actual_ea_buffer2['ea_name_length'].get_value() == 5 + assert actual_ea_buffer2['ea_value_length'].get_value() == 19 + assert actual_ea_buffer2['ea_name'].get_value() == \ + "Title\x00".encode("ascii") + assert actual_ea_buffer2['ea_value'].get_value() == \ + b"Jordan Borean Title" + assert actual_ea_buffer2['padding'].get_value() == b"" + + assert actual1['padding2'].get_value() == b"\x00" * 7 + + assert len(actual2) == 32 + assert actual2['next'].get_value() == 32 + assert actual2['name_offset'].get_value() == 16 + assert actual2['name_length'].get_value() == 4 + assert actual2['reserved'].get_value() == 0 + assert actual2['data_offset'].get_value() == 24 + assert actual2['data_length'].get_value() == 8 + assert actual2['buffer_name'].get_value() == b"\x41\x6c\x53\x69" + assert actual2['padding'].get_value() == b"\x00\x00\x00\x00" + alloc_data = actual2['buffer_data'].get_value() + alloc = SMB2CreateAllocationSize() + alloc_data = alloc.unpack(alloc_data) + assert alloc_data == b"" + assert alloc['allocation_size'].get_value() == 1024 + assert actual2['padding2'].get_value() == b"" + + assert len(actual3) == 20 + assert actual3['next'].get_value() == 0 + assert actual3['name_offset'].get_value() == 16 + assert actual3['name_length'].get_value() == 4 + assert actual3['reserved'].get_value() == 0 + assert actual3['data_offset'].get_value() == 0 + assert actual3['data_length'].get_value() == 0 + assert actual3['buffer_name'].get_value() == b"\x51\x46\x69\x64" + assert actual3['padding'].get_value() == b"" + assert actual3['buffer_data'].get_value() == b"" + assert actual3['padding2'].get_value() == b"" + + def test_get_context_data_known(self): + message = SMB2CreateContextRequest() + data = b"\x00\x00\x00\x00" \ + b"\x10\x00" \ + b"\x04\x00" \ + b"\x00\x00" \ + b"\x18\x00" \ + b"\x20\x00\x00\x00" \ + b"\x51\x46\x69\x64" \ + b"\x00\x00\x00\x00" \ + b"\xed\x5a\x00\x00\x00\x00\x99\x00" \ + b"\x30\x50\xd7\xd8\x04\x82\xff\xff" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" + message.unpack(data) + actual = message.get_context_data() + assert isinstance(actual, SMB2CreateQueryOnDiskIDResponse) + assert actual['disk_file_id'].get_value() == 43065671436753645 + assert actual['volume_id'].get_value() == 18446605556062310448 + assert actual['reserved'].get_value() == b"\x00" * 16 + + def test_get_context_data_unknown(self): + message = SMB2CreateContextRequest() + data = b"\x00\x00\x00\x00" \ + b"\x10\x00" \ + b"\x04\x00" \ + b"\x00\x00" \ + b"\x18\x00" \ + b"\x04\x00\x00\x00" \ + b"\x45\x78\x74\x41" \ + b"\x00\x00\x00\x00" \ + b"\x20\x00\x00\x00" + message.unpack(data) + actual = message.get_context_data() + assert actual == b"\x20\x00\x00\x00" + + +class TestSMB2CreateEABuffer(object): + + def test_create_message(self): + msg1 = SMB2CreateEABuffer() + msg1['ea_name'] = "Authors\x00".encode('ascii') + msg1['ea_value'] = b"Jordan Borean" + + msg2 = SMB2CreateEABuffer() + msg2['ea_name'] = "Title\x00".encode("ascii") + msg2['ea_value'] = b"Jordan Borean Title" + + expected = b"\x20\x00\x00\x00" \ + b"\x00" \ + b"\x07" \ + b"\x0d\x00" \ + b"\x41\x75\x74\x68\x6f\x72\x73\x00" \ + b"\x4a\x6f\x72\x64\x61\x6e\x20\x42" \ + b"\x6f\x72\x65\x61\x6e" \ + b"\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00" \ + b"\x05" \ + b"\x13\x00" \ + b"\x54\x69\x74\x6c\x65\x00" \ + b"\x4a\x6f\x72\x64\x61\x6e\x20\x42" \ + b"\x6f\x72\x65\x61\x6e\x20\x54\x69" \ + b"\x74\x6c\x65" + + # size of msg1 won't have any padding as we haven't set the next offset + assert len(msg1) == 29 + assert len(msg2) == 33 + + # size of the padding changes in this argument as we add multiple + # together + actual = SMB2CreateEABuffer.pack_multiple([msg1, msg2]) + assert len(msg1) == 32 + assert len(msg2) == 33 + assert actual == expected + + def test_parse_message(self): + actual1 = SMB2CreateEABuffer() + actual2 = SMB2CreateEABuffer() + data = b"\x20\x00\x00\x00" \ + b"\x00" \ + b"\x07" \ + b"\x0d\x00" \ + b"\x41\x75\x74\x68\x6f\x72\x73\x00" \ + b"\x4a\x6f\x72\x64\x61\x6e\x20\x42" \ + b"\x6f\x72\x65\x61\x6e" \ + b"\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00" \ + b"\x05" \ + b"\x13\x00" \ + b"\x54\x69\x74\x6c\x65\x00" \ + b"\x4a\x6f\x72\x64\x61\x6e\x20\x42" \ + b"\x6f\x72\x65\x61\x6e\x20\x54\x69" \ + b"\x74\x6c\x65" + data = actual1.unpack(data) + data = actual2.unpack(data) + assert len(actual1) == 32 + assert actual1['next_entry_offset'].get_value() == 32 + assert actual1['flags'].get_value() == 0 + assert actual1['ea_name_length'].get_value() == 7 + assert actual1['ea_value_length'].get_value() == 13 + assert actual1['ea_name'].get_value() == "Authors\x00".encode("ascii") + assert actual1['ea_value'].get_value() == b"Jordan Borean" + assert actual1['padding'].get_value() == b"\x00\x00\x00" + assert len(actual2) == 33 + assert actual2['next_entry_offset'].get_value() == 0 + assert actual2['flags'].get_value() == 0 + assert actual2['ea_name_length'].get_value() == 5 + assert actual2['ea_value_length'].get_value() == 19 + assert actual2['ea_name'].get_value() == "Title\x00".encode("ascii") + assert actual2['ea_value'].get_value() == b"Jordan Borean Title" + assert actual2['padding'].get_value() == b"" + + +class TestSMB2CreateDurableHandleRequest(object): + + def test_create_message(self): + message = SMB2CreateDurableHandleRequest() + expected = b"\x00" * 16 + actual = message.pack() + assert len(message) == 16 + assert actual == expected + + def test_parse_message(self): + actual = SMB2CreateDurableHandleRequest() + data = b"\x00" * 16 + data = actual.unpack(data) + assert len(actual) == 16 + assert data == b"" + assert actual['durable_request'].get_value() == b"\x00" * 16 + + +class TestSMB2CreateDurableHandleResponse(object): + + def test_create_message(self): + message = SMB2CreateDurableHandleResponse() + expected = b"\x00" * 8 + actual = message.pack() + assert len(message) == 8 + assert actual == expected + + def test_parse_message(self): + actual = SMB2CreateDurableHandleResponse() + data = b"\x00" * 8 + data = actual.unpack(data) + assert len(actual) == 8 + assert data == b"" + assert actual['reserved'].get_value() == 0 + + +class TestSMB2CreateDurableHandleReconnect(object): + + def test_create_message(self): + message = SMB2CreateDurableHandleReconnect() + message['data'] = b"\xff" * 16 + expected = b"\xff" * 16 + actual = message.pack() + assert len(message) == 16 + assert actual == expected + + def test_parse_message(self): + actual = SMB2CreateDurableHandleReconnect() + data = b"\xff" * 16 + data = actual.unpack(data) + assert len(actual) == 16 + assert data == b"" + assert actual['data'].pack() == b"\xff" * 16 + + +class TestSMB2CreateQueryMaximalAccessRequest(object): + + def test_create_message(self): + message = SMB2CreateQueryMaximalAccessRequest() + message['timestamp'] = datetime.utcfromtimestamp(0) + expected = b"\x00\x80\x3e\xd5\xde\xb1\x9d\x01" + actual = message.pack() + assert len(message) == 8 + assert actual == expected + + def test_parse_message(self): + actual = SMB2CreateQueryMaximalAccessRequest() + data = b"\x00\x80\x3e\xd5\xde\xb1\x9d\x01" + data = actual.unpack(data) + assert len(actual) == 8 + assert data == b"" + assert actual['timestamp'].get_value() == datetime.utcfromtimestamp(0) + + +class TestSMB2CreateQueryMaximalAccessResponse(object): + + def test_create_message(self): + message = SMB2CreateQueryMaximalAccessResponse() + message['maximal_access'] = 2032127 + expected = b"\x00\x00\x00\x00" \ + b"\xff\x01\x1f\x00" + actual = message.pack() + assert len(message) == 8 + assert actual == expected + + def test_parse_message(self): + actual = SMB2CreateQueryMaximalAccessResponse() + data = b"\x00\x00\x00\x00" \ + b"\xff\x01\x1f\x00" + data = actual.unpack(data) + assert len(actual) == 8 + assert data == b"" + assert actual['query_status'].get_value() == NtStatus.STATUS_SUCCESS + assert actual['maximal_access'].get_value() == 2032127 + + +class TestSMB2CreateAllocationSize(object): + + def test_create_message(self): + message = SMB2CreateAllocationSize() + message['allocation_size'] = 1024 + expected = b"\x00\x04\x00\x00\x00\x00\x00\x00" + actual = message.pack() + assert len(message) == 8 + assert actual == expected + + def test_parse_message(self): + actual = SMB2CreateAllocationSize() + data = b"\x00\x04\x00\x00\x00\x00\x00\x00" + data = actual.unpack(data) + assert len(actual) == 8 + assert data == b"" + assert actual['allocation_size'].get_value() == 1024 + + +class TestSMB2CreateTimewarpToken(object): + + def test_create_message(self): + message = SMB2CreateTimewarpToken() + message['timestamp'] = datetime.utcfromtimestamp(0) + expected = b"\x00\x80\x3e\xd5\xde\xb1\x9d\x01" + actual = message.pack() + assert len(message) == 8 + assert actual == expected + + def test_parse_message(self): + actual = SMB2CreateTimewarpToken() + data = b"\x00\x80\x3e\xd5\xde\xb1\x9d\x01" + data = actual.unpack(data) + assert len(actual) == 8 + assert data == b"" + assert actual['timestamp'].get_value() == datetime.utcfromtimestamp(0) + + +class TestSMB2CreateRequestLease(object): + + def test_create_message(self): + message = SMB2CreateRequestLease() + message['lease_key'] = b"\xff" * 16 + message['lease_state'].set_flag(LeaseState.SMB2_LEASE_HANDLE_CACHING) + message['lease_duration'] = 10 + expected = b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\x02\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x0a\x00\x00\x00\x00\x00\x00\x00" + actual = message.pack() + assert len(message) == 32 + assert actual == expected + + def test_parse_message(self): + actual = SMB2CreateRequestLease() + data = b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\x02\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x0a\x00\x00\x00\x00\x00\x00\x00" + data = actual.unpack(data) + assert len(actual) == 32 + assert data == b"" + assert actual['lease_key'].get_value() == b"\xff" * 16 + assert actual['lease_state'].get_value() == \ + LeaseState.SMB2_LEASE_HANDLE_CACHING + assert actual['lease_flags'].get_value() == 0 + assert actual['lease_duration'].get_value() == 10 + + +class TestSMB2CreateResponseLease(object): + + def test_create_message(self): + message = SMB2CreateResponseLease() + message['lease_key'] = b"\xff" * 16 + message['lease_state'].set_flag(LeaseState.SMB2_LEASE_HANDLE_CACHING) + message['lease_flags'].set_flag( + LeaseResponseFlags.SMB2_LEASE_FLAG_BREAK_IN_PROGRESS + ) + message['lease_duration'] = 12 + expected = b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\x02\x00\x00\x00" \ + b"\x02\x00\x00\x00" \ + b"\x0c\x00\x00\x00\x00\x00\x00\x00" + actual = message.pack() + assert len(message) == 32 + assert actual == expected + + def test_parse_message(self): + actual = SMB2CreateResponseLease() + data = b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\x02\x00\x00\x00" \ + b"\x02\x00\x00\x00" \ + b"\x0c\x00\x00\x00\x00\x00\x00\x00" + data = actual.unpack(data) + assert len(actual) == 32 + assert data == b"" + assert actual['lease_key'].get_value() == b"\xff" * 16 + assert actual['lease_state'].get_value() == \ + LeaseState.SMB2_LEASE_HANDLE_CACHING + assert actual['lease_flags'].get_value() == \ + LeaseResponseFlags.SMB2_LEASE_FLAG_BREAK_IN_PROGRESS + assert actual['lease_duration'].get_value() == 12 + + +class TestSMB2CreateQueryOnDiskIDResponse(object): + + def test_create_message(self): + message = SMB2CreateQueryOnDiskIDResponse() + message['disk_file_id'] = 43065671436753645 + message['volume_id'] = 18446605556062310448 + expected = b"\xed\x5a\x00\x00\x00\x00\x99\x00" \ + b"\x30\x50\xd7\xd8\x04\x82\xff\xff" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" + actual = message.pack() + assert len(message) == 32 + assert actual == expected + + def test_parse_message(self): + actual = SMB2CreateQueryOnDiskIDResponse() + data = b"\xed\x5a\x00\x00\x00\x00\x99\x00" \ + b"\x30\x50\xd7\xd8\x04\x82\xff\xff" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" + data = actual.unpack(data) + assert len(actual) == 32 + assert data == b"" + assert actual['disk_file_id'].get_value() == 43065671436753645 + assert actual['volume_id'].get_value() == 18446605556062310448 + assert actual['reserved'].get_value() == b"\x00" * 16 + + +class TestSMB2CreateRequestLeaseV2(object): + + def test_create_message(self): + message = SMB2CreateRequestLeaseV2() + message['lease_key'] = b"\xff" * 16 + message['lease_state'] = LeaseState.SMB2_LEASE_READ_CACHING + message['lease_flags'] = \ + LeaseRequestFlags.SMB2_LEASE_FLAG_PARENT_LEASE_KEY_SET + message['lease_duration'] = 10 + message['parent_lease_key'] = b"\xee" * 16 + message['epoch'] = b"\xdd" * 16 + expected = b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\x01\x00\x00\x00" \ + b"\x04\x00\x00\x00" \ + b"\x0a\x00\x00\x00\x00\x00\x00\x00" \ + b"\xee\xee\xee\xee\xee\xee\xee\xee" \ + b"\xee\xee\xee\xee\xee\xee\xee\xee" \ + b"\xdd\xdd\xdd\xdd\xdd\xdd\xdd\xdd" \ + b"\xdd\xdd\xdd\xdd\xdd\xdd\xdd\xdd" \ + b"\x00\x00" + actual = message.pack() + assert len(message) == 66 + assert actual == expected + + def test_parse_message(self): + actual = SMB2CreateRequestLeaseV2() + data = b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\x01\x00\x00\x00" \ + b"\x04\x00\x00\x00" \ + b"\x0a\x00\x00\x00\x00\x00\x00\x00" \ + b"\xee\xee\xee\xee\xee\xee\xee\xee" \ + b"\xee\xee\xee\xee\xee\xee\xee\xee" \ + b"\xdd\xdd\xdd\xdd\xdd\xdd\xdd\xdd" \ + b"\xdd\xdd\xdd\xdd\xdd\xdd\xdd\xdd" \ + b"\x00\x00" + data = actual.unpack(data) + assert len(actual) == 66 + assert data == b"" + assert actual['lease_key'].get_value() == b"\xff" * 16 + assert actual['lease_state'].get_value() == \ + LeaseState.SMB2_LEASE_READ_CACHING + assert actual['lease_flags'].get_value() == \ + LeaseRequestFlags.SMB2_LEASE_FLAG_PARENT_LEASE_KEY_SET + assert actual['lease_duration'].get_value() == 10 + assert actual['parent_lease_key'].get_value() == b"\xee" * 16 + assert actual['epoch'].get_value() == b"\xdd" * 16 + assert actual['reserved'].get_value() == 0 + + +class TestSMB2CreateResponseLeaseV2(object): + + def test_create_message(self): + message = SMB2CreateResponseLeaseV2() + message['lease_key'] = b"\xff" * 16 + message['lease_state'] = LeaseState.SMB2_LEASE_READ_CACHING + message['flags'] = \ + LeaseRequestFlags.SMB2_LEASE_FLAG_PARENT_LEASE_KEY_SET + message['lease_duration'] = 10 + message['parent_lease_key'] = b"\xee" * 16 + message['epoch'] = 100 + expected = b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\x01\x00\x00\x00" \ + b"\x04\x00\x00\x00" \ + b"\x0a\x00\x00\x00\x00\x00\x00\x00" \ + b"\xee\xee\xee\xee\xee\xee\xee\xee" \ + b"\xee\xee\xee\xee\xee\xee\xee\xee" \ + b"\x64\x00" \ + b"\x00\x00" + actual = message.pack() + assert len(message) == 52 + assert actual == expected + + def test_parse_message(self): + actual = SMB2CreateResponseLeaseV2() + data = b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\x01\x00\x00\x00" \ + b"\x04\x00\x00\x00" \ + b"\x0a\x00\x00\x00\x00\x00\x00\x00" \ + b"\xee\xee\xee\xee\xee\xee\xee\xee" \ + b"\xee\xee\xee\xee\xee\xee\xee\xee" \ + b"\x64\x00" \ + b"\x00\x00" + data = actual.unpack(data) + assert len(actual) == 52 + assert data == b"" + assert actual['lease_key'].get_value() == b"\xff" * 16 + assert actual['lease_state'].get_value() == \ + LeaseState.SMB2_LEASE_READ_CACHING + assert actual['flags'].get_value() == \ + LeaseRequestFlags.SMB2_LEASE_FLAG_PARENT_LEASE_KEY_SET + assert actual['lease_duration'].get_value() == 10 + assert actual['parent_lease_key'].get_value() == b"\xee" * 16 + assert actual['epoch'].get_value() == 100 + assert actual['reserved'].get_value() == 0 + + +class TestSMB2CreateDurableHandleRequestV2(object): + + def test_create_message(self): + message = SMB2CreateDurableHandleRequestV2() + message['timeout'] = 100 + message['flags'] = DurableHandleFlags.SMB2_DHANDLE_FLAG_PERSISTENT + message['create_guid'] = b"\xff" * 16 + expected = b"\x64\x00\x00\x00" \ + b"\x02\x00\x00\x00" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" + actual = message.pack() + assert len(message) == 32 + assert actual == expected + + def test_parse_message(self): + actual = SMB2CreateDurableHandleRequestV2() + data = b"\x64\x00\x00\x00" \ + b"\x02\x00\x00\x00" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" + data = actual.unpack(data) + assert len(actual) == 32 + assert data == b"" + assert actual['timeout'].get_value() == 100 + assert actual['flags'].get_value() == \ + DurableHandleFlags.SMB2_DHANDLE_FLAG_PERSISTENT + assert actual['reserved'].get_value() == 0 + assert actual['create_guid'].get_value() == \ + uuid.UUID(bytes=b"\xff" * 16) + + +class TestSMB2CreateDurableHandleReconnectV2(object): + + def test_create_message(self): + message = SMB2CreateDurableHandleReconnectV2() + message['file_id'] = b"\xff" * 16 + message['create_guid'] = b"\xee" * 16 + message['flags'] = DurableHandleFlags.SMB2_DHANDLE_FLAG_PERSISTENT + expected = b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xee\xee\xee\xee\xee\xee\xee\xee" \ + b"\xee\xee\xee\xee\xee\xee\xee\xee" \ + b"\x02\x00\x00\x00" + actual = message.pack() + assert len(message) == 36 + assert actual == expected + + def test_parse_message(self): + actual = SMB2CreateDurableHandleReconnectV2() + data = b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xee\xee\xee\xee\xee\xee\xee\xee" \ + b"\xee\xee\xee\xee\xee\xee\xee\xee" \ + b"\x02\x00\x00\x00" + data = actual.unpack(data) + assert len(actual) == 36 + assert data == b"" + assert actual['file_id'].pack() == b"\xff" * 16 + assert actual['create_guid'].get_value() == \ + uuid.UUID(bytes=b"\xee" * 16) + assert actual['flags'].get_value() == \ + DurableHandleFlags.SMB2_DHANDLE_FLAG_PERSISTENT + + +class TestSMB2CreateDurableHandleResponseV2(object): + + def test_create_message(self): + message = SMB2CreateDurableHandleResponseV2() + message['timeout'] = 10 + expected = b"\x0a\x00\x00\x00" \ + b"\x00\x00\x00\x00" + actual = message.pack() + assert len(message) == 8 + assert actual == expected + + def test_parse_message(self): + actual = SMB2CreateDurableHandleResponseV2() + data = b"\x0a\x00\x00\x00" \ + b"\x00\x00\x00\x00" + data = actual.unpack(data) + assert len(actual) == 8 + assert data == b"" + assert actual['timeout'].get_value() == 10 + assert actual['flags'].get_value() == 0 + + +class TestSMB2CreateAppInstanceId(object): + + def test_create_message(self): + message = SMB2CreateAppInstanceId() + message['app_instance_id'] = b"\xff" * 16 + expected = b"\x14\x00" \ + b"\x00\x00" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" + actual = message.pack() + assert len(message) == 20 + assert actual == expected + + def test_parse_message(self): + actual = SMB2CreateAppInstanceId() + data = b"\x14\x00" \ + b"\x00\x00" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" + data = actual.unpack(data) + assert len(actual) == 20 + assert data == b"" + assert actual['structure_size'].get_value() == 20 + assert actual['reserved'].get_value() == 0 + assert actual['app_instance_id'].get_value() == b"\xff" * 16 + + +class TestSMB2SVHDXOpenDeviceContextRequest(object): + + def test_create_message(self): + message = SMB2SVHDXOpenDeviceContextRequest() + message['initiator_id'] = b"\xff" * 16 + message['originator_flags'] = \ + SVHDXOriginatorFlags.SVHDX_ORIGINATOR_VHDMP + message['open_request_id'] = 5 + message['initiator_host_name'] = "hostname".encode('utf-16-le') + expected = b"\x01\x00\x00\x00" \ + b"\x01" \ + b"\x00\x00\x00" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\x04\x00\x00\x00" \ + b"\x05\x00\x00\x00\x00\x00\x00\x00" \ + b"\x10\x00" \ + b"\x68\x00\x6f\x00\x73\x00\x74\x00" \ + b"\x6e\x00\x61\x00\x6d\x00\x65\x00" + actual = message.pack() + assert len(message) == 54 + assert actual == expected + + def test_parse_message(self): + actual = SMB2SVHDXOpenDeviceContextRequest() + data = b"\x01\x00\x00\x00" \ + b"\x01" \ + b"\x00\x00\x00" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\x04\x00\x00\x00" \ + b"\x05\x00\x00\x00\x00\x00\x00\x00" \ + b"\x10\x00" \ + b"\x68\x00\x6f\x00\x73\x00\x74\x00" \ + b"\x6e\x00\x61\x00\x6d\x00\x65\x00" + data = actual.unpack(data) + assert len(actual) == 54 + assert data == b"" + assert actual['version'].get_value() == 1 + assert actual['has_initiator_id'].get_value() + assert actual['reserved'].get_value() == b"\x00\x00\x00" + assert actual['initiator_id'].get_value() == \ + uuid.UUID(bytes=b"\xff" * 16) + assert actual['originator_flags'].get_value() == \ + SVHDXOriginatorFlags.SVHDX_ORIGINATOR_VHDMP + assert actual['open_request_id'].get_value() == 5 + assert actual['initiator_host_name_length'].get_value() == 16 + assert actual['initiator_host_name'].get_value() == \ + "hostname".encode("utf-16-le") + + +class TestSMB2SVHDXOpenDeviceContextResponse(object): + + def test_create_message(self): + message = SMB2SVHDXOpenDeviceContextResponse() + message['initiator_id'] = b"\xff" * 16 + message['open_request_id'] = 20 + message['initiator_host_name'] = "hostname".encode("utf-16-le") + expected = b"\x01\x00\x00\x00" \ + b"\x01" \ + b"\x00\x00\x00" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x14\x00\x00\x00\x00\x00\x00\x00" \ + b"\x10\x00" \ + b"\x68\x00\x6f\x00\x73\x00\x74\x00" \ + b"\x6e\x00\x61\x00\x6d\x00\x65\x00" + actual = message.pack() + assert len(message) == 58 + assert actual == expected + + def test_parse_message(self): + actual = SMB2SVHDXOpenDeviceContextResponse() + data = b"\x01\x00\x00\x00" \ + b"\x01" \ + b"\x00\x00\x00" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x14\x00\x00\x00\x00\x00\x00\x00" \ + b"\x10\x00" \ + b"\x68\x00\x6f\x00\x73\x00\x74\x00" \ + b"\x6e\x00\x61\x00\x6d\x00\x65\x00" + data = actual.unpack(data) + assert len(actual) == 58 + assert data == b"" + assert actual['version'].get_value() == 1 + assert actual['has_initiator_id'].get_value() + assert actual['reserved'].get_value() == b"\x00\x00\x00" + assert actual['initiator_id'].get_value() == \ + uuid.UUID(bytes=b"\xff" * 16) + assert actual['flags'].get_value() == 0 + assert actual['originator_flags'].get_value() == 0 + assert actual['open_request_id'].get_value() == 20 + assert actual['initiator_host_name_length'].get_value() == 16 + assert actual['initiator_host_name'].get_value() == \ + "hostname".encode("utf-16-le") + + +class TestSMB2SVHDXOpenDeviceContextV2Request(object): + + def test_create_message(self): + message = SMB2SVHDXOpenDeviceContextV2Request() + message['initiator_id'] = b"\xff" * 16 + message['originator_flags'] = \ + SVHDXOriginatorFlags.SVHDX_ORIGINATOR_VHDMP + message['open_request_id'] = 5 + message['initiator_host_name'] = "hostname".encode('utf-16-le') + expected = b"\x02\x00\x00\x00" \ + b"\x01" \ + b"\x00\x00\x00" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\x04\x00\x00\x00" \ + b"\x05\x00\x00\x00\x00\x00\x00\x00" \ + b"\x10\x00" \ + b"\x68\x00\x6f\x00\x73\x00\x74\x00" \ + b"\x6e\x00\x61\x00\x6d\x00\x65\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" + actual = message.pack() + assert len(message) == 78 + assert actual == expected + + def test_parse_message(self): + actual = SMB2SVHDXOpenDeviceContextV2Request() + data = b"\x02\x00\x00\x00" \ + b"\x01" \ + b"\x00\x00\x00" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\x04\x00\x00\x00" \ + b"\x05\x00\x00\x00\x00\x00\x00\x00" \ + b"\x10\x00" \ + b"\x68\x00\x6f\x00\x73\x00\x74\x00" \ + b"\x6e\x00\x61\x00\x6d\x00\x65\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" + data = actual.unpack(data) + assert len(actual) == 78 + assert data == b"" + assert actual['version'].get_value() == 2 + assert actual['has_initiator_id'].get_value() + assert actual['reserved'].get_value() == b"\x00\x00\x00" + assert actual['initiator_id'].get_value() == \ + uuid.UUID(bytes=b"\xff" * 16) + assert actual['originator_flags'].get_value() == \ + SVHDXOriginatorFlags.SVHDX_ORIGINATOR_VHDMP + assert actual['open_request_id'].get_value() == 5 + assert actual['initiator_host_name_length'].get_value() == 16 + assert actual['initiator_host_name'].get_value() == \ + "hostname".encode("utf-16-le") + assert actual['virtual_disk_properties_initialized'].get_value() == 0 + assert actual['server_service_version'].get_value() == 0 + assert actual['virtual_sector_size'].get_value() == 0 + assert actual['physical_sector_size'].get_value() == 0 + assert actual['virtual_size'].get_value() == 0 + + +class TestSMB2SVHDXOpenDeviceContextV2Response(object): + + def test_create_message(self): + message = SMB2SVHDXOpenDeviceContextV2Response() + message['initiator_id'] = b"\xff" * 16 + message['open_request_id'] = 20 + message['initiator_host_name'] = "hostname".encode("utf-16-le") + expected = b"\x02\x00\x00\x00" \ + b"\x01" \ + b"\x00\x00\x00" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x14\x00\x00\x00\x00\x00\x00\x00" \ + b"\x10\x00" \ + b"\x68\x00\x6f\x00\x73\x00\x74\x00" \ + b"\x6e\x00\x61\x00\x6d\x00\x65\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" + actual = message.pack() + assert len(message) == 82 + assert actual == expected + + def test_parse_message(self): + actual = SMB2SVHDXOpenDeviceContextV2Response() + data = b"\x02\x00\x00\x00" \ + b"\x01" \ + b"\x00\x00\x00" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x14\x00\x00\x00\x00\x00\x00\x00" \ + b"\x10\x00" \ + b"\x68\x00\x6f\x00\x73\x00\x74\x00" \ + b"\x6e\x00\x61\x00\x6d\x00\x65\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" + data = actual.unpack(data) + assert len(actual) == 82 + assert data == b"" + assert actual['version'].get_value() == 2 + assert actual['has_initiator_id'].get_value() + assert actual['reserved'].get_value() == b"\x00\x00\x00" + assert actual['initiator_id'].get_value() == \ + uuid.UUID(bytes=b"\xff" * 16) + assert actual['flags'].get_value() == 0 + assert actual['originator_flags'].get_value() == 0 + assert actual['open_request_id'].get_value() == 20 + assert actual['initiator_host_name_length'].get_value() == 16 + assert actual['initiator_host_name'].get_value() == \ + "hostname".encode("utf-16-le") + assert actual['virtual_disk_properties_initialized'].get_value() == 0 + assert actual['server_service_version'].get_value() == 0 + assert actual['virtual_sector_size'].get_value() == 0 + assert actual['physical_sector_size'].get_value() == 0 + assert actual['virtual_size'].get_value() == 0 + + +class TestSMB2CreateAppInstanceVersion(object): + + def test_create_message(self): + message = SMB2CreateAppInstanceVersion() + message['app_instance_version_high'] = 10 + message['app_instance_version_low'] = 10 + expected = b"\x18\x00" \ + b"\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x0a\x00\x00\x00\x00\x00\x00\x00" \ + b"\x0a\x00\x00\x00\x00\x00\x00\x00" + actual = message.pack() + assert len(message) == 24 + assert actual == expected + + def test_parse_message(self): + actual = SMB2CreateAppInstanceVersion() + data = b"\x18\x00" \ + b"\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x0a\x00\x00\x00\x00\x00\x00\x00" \ + b"\x0a\x00\x00\x00\x00\x00\x00\x00" + data = actual.unpack(data) + assert len(actual) == 24 + assert data == b"" + assert actual['structure_size'].get_value() == 24 + assert actual['reserved'].get_value() == 0 + assert actual['padding'].get_value() == 0 + assert actual['app_instance_version_high'].get_value() == 10 + assert actual['app_instance_version_low'].get_value() == 10 diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py new file mode 100644 index 00000000..39068a82 --- /dev/null +++ b/tests/test_exceptions.py @@ -0,0 +1,509 @@ +import pytest + +from smbprotocol.connection import SMB2HeaderResponse, NtStatus, Dialects +from smbprotocol.exceptions import ErrorContextId, IpAddrType, \ + SMBAuthenticationError, SMBException, SMBResponseException, \ + SMBUnsupportedFeature, SMB2ErrorContextResponse, SMB2ErrorResponse, \ + SMB2MoveDstIpAddrStructure, SMB2ShareRedirectErrorContext, \ + SMB2SymbolicLinkErrorResponse, SymbolicLinkErrorFlags + + +class TestSMBException(object): + + def test_exception(self): + with pytest.raises(SMBException) as exc: + raise SMBException("smb error") + assert str(exc.value) == "smb error" + + +class TestSMBAuthenticationError(object): + + def test_exception(self): + with pytest.raises(SMBAuthenticationError) as exc: + raise SMBAuthenticationError("auth error") + assert str(exc.value) == "auth error" + + def test_caught_with_smbexception(self): + with pytest.raises(SMBException) as exc: + raise SMBAuthenticationError("auth error") + assert str(exc.value) == "auth error" + + +class TestSMBUnsupportedFeature(object): + + def test_exception_needs_newer(self): + with pytest.raises(SMBUnsupportedFeature) as exc: + raise SMBUnsupportedFeature(Dialects.SMB_3_0_0, Dialects.SMB_3_1_1, + "feature", True) + assert str(exc.value) == "feature is not available on the " \ + "negotiated dialect (768) SMB_3_0_0, " \ + "requires dialect (785) SMB_3_1_1 or newer" + + def test_exception_needs_older(self): + with pytest.raises(SMBUnsupportedFeature) as exc: + raise SMBUnsupportedFeature(Dialects.SMB_3_0_0, Dialects.SMB_3_1_1, + "feature", False) + assert str(exc.value) == "feature is not available on the " \ + "negotiated dialect (768) SMB_3_0_0, " \ + "requires dialect (785) SMB_3_1_1 or older" + + def test_exception_no_suffix(self): + with pytest.raises(SMBUnsupportedFeature) as exc: + raise SMBUnsupportedFeature(Dialects.SMB_3_0_0, Dialects.SMB_3_1_1, + "feature") + assert str(exc.value) == "feature is not available on the " \ + "negotiated dialect (768) SMB_3_0_0, " \ + "requires dialect (785) SMB_3_1_1" + + +class TestSMBResponseException(object): + + def test_throw_default_exception(self): + error_resp = SMB2ErrorResponse() + header = self._get_header(error_resp) + try: + raise SMBResponseException(header, header['status'].get_value(), + header['message_id'].get_value()) + except SMBResponseException as exc: + assert exc.error_details == [] + exp_resp = "Received unexpected status from the server: " \ + "(3221225485) STATUS_INVALID_PARAMETER: 0xc000000d" + assert exc.message == exp_resp + assert str(exc) == exp_resp + assert exc.status == NtStatus.STATUS_INVALID_PARAMETER + + def test_throw_exception_with_symlink_redir(self): + symlnk_redir = SMB2SymbolicLinkErrorResponse() + symlnk_redir.set_name(r"C:\temp\folder", r"\??\C:\temp\folder") + + cont_resp = SMB2ErrorContextResponse() + cont_resp['error_context_data'] = symlnk_redir + + error_resp = SMB2ErrorResponse() + error_resp['error_data'] = [cont_resp] + header = self._get_header(error_resp, + NtStatus.STATUS_STOPPED_ON_SYMLINK) + try: + raise SMBResponseException(header, header['status'].get_value(), + header['message_id'].get_value()) + except SMBResponseException as exc: + assert len(exc.error_details) == 1 + err1 = exc.error_details[0] + assert isinstance(err1, SMB2SymbolicLinkErrorResponse) + exp_resp = "Received unexpected status from the server: " \ + "(2147483693) STATUS_STOPPED_ON_SYMLINK: 0x8000002d " \ + "- Flag: (0) SYMLINK_FLAG_ABSOLUTE, " \ + r"Print Name: C:\temp\folder, " \ + r"Substitute Name: \??\C:\temp\folder" + assert exc.message == exp_resp + assert str(exc) == exp_resp + assert exc.status == NtStatus.STATUS_STOPPED_ON_SYMLINK + + def test_throw_exception_with_share_redir(self): + ip_addr = SMB2MoveDstIpAddrStructure() + ip_addr['type'] = IpAddrType.MOVE_DST_IPADDR_V4 + ip_addr.set_ipaddress("192.168.1.100") + + share_redir = SMB2ShareRedirectErrorContext() + share_redir['ip_addr_move_list'] = [ip_addr] + share_redir['resource_name'] = "resource".encode('utf-16-le') + + cont_resp = SMB2ErrorContextResponse() + cont_resp['error_id'] = ErrorContextId.SMB2_ERROR_ID_SHARE_REDIRECT + cont_resp['error_context_data'] = share_redir + + error_resp = SMB2ErrorResponse() + error_resp['error_data'] = [cont_resp] + header = self._get_header(error_resp, + NtStatus.STATUS_BAD_NETWORK_NAME) + try: + raise SMBResponseException(header, header['status'].get_value(), + header['message_id'].get_value()) + except SMBResponseException as exc: + assert len(exc.error_details) == 1 + err1 = exc.error_details[0] + assert isinstance(err1, SMB2ShareRedirectErrorContext) + exp_resp = "Received unexpected status from the server: " \ + "(3221225676) STATUS_BAD_NETWORK_NAME: 0xc00000cc - " \ + "IP Addresses: '192.168.1.100', Resource Name: resource" + assert exc.message == exp_resp + assert str(exc) == exp_resp + assert exc.status == NtStatus.STATUS_BAD_NETWORK_NAME + + def test_throw_exception_with_raw_context(self): + error_resp = SMB2ErrorResponse() + cont_resp = SMB2ErrorContextResponse() + cont_resp['error_context_data'] = b"\x01\x02\x03\x04" + error_resp['error_data'] = [cont_resp] + header = self._get_header(error_resp) + try: + raise SMBResponseException(header, header['status'].get_value(), + header['message_id'].get_value()) + except SMBResponseException as exc: + assert len(exc.error_details) == 1 + assert exc.error_details[0] == b"\x01\x02\x03\x04" + exp_resp = "Received unexpected status from the server: " \ + "(3221225485) STATUS_INVALID_PARAMETER: 0xc000000d - " \ + "Raw: 01020304" + assert exc.message == exp_resp + assert str(exc) == exp_resp + assert exc.status == NtStatus.STATUS_INVALID_PARAMETER + + def test_throw_exception_with_multiple_contexts(self): + error_resp = SMB2ErrorResponse() + cont_resp1 = SMB2ErrorContextResponse() + cont_resp1['error_context_data'] = b"\x01\x02\x03\x04" + cont_resp2 = SMB2ErrorContextResponse() + cont_resp2['error_context_data'] = b"\x05\x06\x07\x08" + error_resp['error_data'] = [ + cont_resp1, cont_resp2 + ] + header = self._get_header(error_resp) + try: + raise SMBResponseException(header, header['status'].get_value(), + header['message_id'].get_value()) + except SMBResponseException as exc: + assert len(exc.error_details) == 2 + assert exc.error_details[0] == b"\x01\x02\x03\x04" + assert exc.error_details[1] == b"\x05\x06\x07\x08" + exp_resp = "Received unexpected status from the server: " \ + "(3221225485) STATUS_INVALID_PARAMETER: 0xc000000d - " \ + "Raw: 01020304, Raw: 05060708" + assert exc.message == exp_resp + assert str(exc) == exp_resp + + assert exc.status == NtStatus.STATUS_INVALID_PARAMETER + + def _get_header(self, data, status=NtStatus.STATUS_INVALID_PARAMETER): + header = SMB2HeaderResponse() + header['status'] = status + header['message_id'] = 10 + header['data'] = data + return header + + +class TestSMB2ErrorResponse(object): + + def test_create_message_plain(self): + # This is a plain error response without the error context response + # data appended + message = SMB2ErrorResponse() + expected = b"\x09\x00" \ + b"\x00" \ + b"\x00" \ + b"\x00\x00\x00\x00" + actual = message.pack() + assert len(actual) == 8 + assert actual == expected + + def test_create_message_with_context(self): + message = SMB2ErrorResponse() + error_context = SMB2ErrorContextResponse() + error_context['error_context_data'] = b"\x01\x02\x03\x04" + message['error_data'] = [error_context] + expected = b"\x09\x00" \ + b"\x01" \ + b"\x00" \ + b"\x0c\x00\x00\x00" \ + b"\x04\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x01\x02\x03\x04" + actual = message.pack() + assert len(message) == 20 + assert actual == expected + + def test_parse_message_plain(self): + actual = SMB2ErrorResponse() + data = b"\x09\x00" \ + b"\x00" \ + b"\x00" \ + b"\x00\x00\x00\x00" + data = actual.unpack(data) + assert len(actual) == 8 + assert data == b"" + assert actual['structure_size'].get_value() == 9 + assert actual['error_context_count'].get_value() == 0 + assert actual['reserved'].get_value() == 0 + assert actual['byte_count'].get_value() == 0 + assert actual['error_data'].get_value() == [] + + def test_parse_message_with_context(self): + actual = SMB2ErrorResponse() + data = b"\x09\x00" \ + b"\x01" \ + b"\x00" \ + b"\x0c\x00\x00\x00" \ + b"\x04\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x01\x02\x03\x04" # just a fake bytes value for test + data = actual.unpack(data) + assert len(actual) == 20 + assert data == b"" + assert actual['structure_size'].get_value() == 9 + assert actual['error_context_count'].get_value() == 1 + assert actual['reserved'].get_value() == 0 + assert actual['byte_count'].get_value() == 12 + assert len(actual['error_data']) == 12 + error_data = actual['error_data'].get_value() + assert len(error_data) == 1 + assert error_data[0]['error_data_length'].get_value() == 4 + assert error_data[0]['error_id'].get_value() == \ + SymbolicLinkErrorFlags.SYMLINK_FLAG_ABSOLUTE + assert error_data[0]['error_context_data'].get_value() == \ + b"\x01\x02\x03\x04" + + +class TestSMB2ErrorContextResponse(object): + + def test_create_message(self): + message = SMB2ErrorContextResponse() + message['error_id'] = ErrorContextId.SMB2_ERROR_ID_SHARE_REDIRECT + message['error_context_data'] = b"\x01\x02\x03\x04" + expected = b"\x04\x00\x00\x00" \ + b"\x72\x64\x52\x53" \ + b"\x01\x02\x03\x04" + actual = message.pack() + assert len(message) == 12 + assert actual == expected + + def test_parse_message(self): + actual = SMB2ErrorContextResponse() + data = b"\x04\x00\x00\x00" \ + b"\x72\x64\x52\x53" \ + b"\x01\x02\x03\x04" + data = actual.unpack(data) + assert len(actual) == 12 + assert data == b"" + assert actual['error_data_length'].get_value() == 4 + assert actual['error_id'].get_value() == \ + ErrorContextId.SMB2_ERROR_ID_SHARE_REDIRECT + assert actual['error_context_data'].get_value() == b"\x01\x02\x03\x04" + + +class TestSMB2SymbolicLinkErrorResponse(object): + + def test_create_message(self): + message = SMB2SymbolicLinkErrorResponse() + message.set_name(r"C:\temp\folder", r"\??\C:\temp\folder") + expected = b"\x58\x00\x00\x00" \ + b"\x53\x59\x4d\x4c" \ + b"\x0c\x00\x00\xa0" \ + b"\x4c\x00" \ + b"\x00\x00" \ + b"\x1c\x00" \ + b"\x24\x00" \ + b"\x00\x00" \ + b"\x1c\x00" \ + b"\x00\x00\x00\x00" \ + b"\x43\x00\x3a\x00\x5c\x00\x74\x00" \ + b"\x65\x00\x6d\x00\x70\x00\x5c\x00" \ + b"\x66\x00\x6f\x00\x6c\x00\x64\x00" \ + b"\x65\x00\x72\x00" \ + b"\x5c\x00\x3f\x00\x3f\x00\x5c\x00" \ + b"\x43\x00\x3a\x00\x5c\x00\x74\x00" \ + b"\x65\x00\x6d\x00\x70\x00\x5c\x00" \ + b"\x66\x00\x6f\x00\x6c\x00\x64\x00" \ + b"\x65\x00\x72\x00" + actual = message.pack() + assert len(actual) == 92 + assert actual == expected + + def test_parse_message(self): + actual = SMB2SymbolicLinkErrorResponse() + data = b"\x58\x00\x00\x00" \ + b"\x53\x59\x4d\x4c" \ + b"\x0c\x00\x00\xa0" \ + b"\x4c\x00" \ + b"\x00\x00" \ + b"\x1c\x00" \ + b"\x24\x00" \ + b"\x00\x00" \ + b"\x1c\x00" \ + b"\x00\x00\x00\x00" \ + b"\x43\x00\x3a\x00\x5c\x00\x74\x00" \ + b"\x65\x00\x6d\x00\x70\x00\x5c\x00" \ + b"\x66\x00\x6f\x00\x6c\x00\x64\x00" \ + b"\x65\x00\x72\x00" \ + b"\x5c\x00\x3f\x00\x3f\x00\x5c\x00" \ + b"\x43\x00\x3a\x00\x5c\x00\x74\x00" \ + b"\x65\x00\x6d\x00\x70\x00\x5c\x00" \ + b"\x66\x00\x6f\x00\x6c\x00\x64\x00" \ + b"\x65\x00\x72\x00" + data = actual.unpack(data) + assert len(actual) == 92 + assert data == b"" + assert actual['symlink_length'].get_value() == 88 + assert actual['symlink_error_tag'].get_value() == b"\x53\x59\x4d\x4c" + assert actual['reparse_tag'].get_value() == b"\x0c\x00\x00\xa0" + assert actual['reparse_data_length'].get_value() == 76 + assert actual['unparsed_path_length'].get_value() == 0 + assert actual['substitute_name_offset'].get_value() == 28 + assert actual['substitute_name_length'].get_value() == 36 + assert actual['print_name_offset'].get_value() == 0 + assert actual['print_name_length'].get_value() == 28 + assert actual['flags'].get_value() == \ + SymbolicLinkErrorFlags.SYMLINK_FLAG_ABSOLUTE + assert actual['path_buffer'].get_value() == \ + b"\x43\x00\x3a\x00\x5c\x00\x74\x00" \ + b"\x65\x00\x6d\x00\x70\x00\x5c\x00" \ + b"\x66\x00\x6f\x00\x6c\x00\x64\x00" \ + b"\x65\x00\x72\x00" \ + b"\x5c\x00\x3f\x00\x3f\x00\x5c\x00" \ + b"\x43\x00\x3a\x00\x5c\x00\x74\x00" \ + b"\x65\x00\x6d\x00\x70\x00\x5c\x00" \ + b"\x66\x00\x6f\x00\x6c\x00\x64\x00" \ + b"\x65\x00\x72\x00" + assert actual.get_print_name() == r"C:\temp\folder" + assert actual.get_substitute_name() == r"\??\C:\temp\folder" + + +class TestSMB2ShareRedirectErrorContext(object): + + def test_create_message(self): + message = SMB2ShareRedirectErrorContext() + ip1 = SMB2MoveDstIpAddrStructure() + ip1['type'] = IpAddrType.MOVE_DST_IPADDR_V4 + ip1.set_ipaddress("192.168.1.100") + ip2 = SMB2MoveDstIpAddrStructure() + ip2['type'] = IpAddrType.MOVE_DST_IPADDR_V6 + ip2.set_ipaddress("fe80:12ab:0000:0000:0000:0001:0002:0000") + message['ip_addr_move_list'] = [ + ip1, ip2 + ] + message['resource_name'] = b"\x01\x02\x03\x04" + expected = b"\x4c\x00\x00\x00" \ + b"\x03\x00\x00\x00" \ + b"\x48\x00\x00\x00" \ + b"\x04\x00\x00\x00" \ + b"\x00\x00" \ + b"\x00\x00" \ + b"\x02\x00\x00\x00" \ + b"\x01\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\xc0\xa8\x01\x64" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x02\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\xfe\x80\x12\xab\x00\x00\x00\x00" \ + b"\x00\x00\x00\x01\x00\x02\x00\x00" \ + b"\x01\x02\x03\x04" + actual = message.pack() + assert len(message) == 76 + assert actual == expected + + def test_parse_message(self): + actual = SMB2ShareRedirectErrorContext() + data = b"\x4c\x00\x00\x00" \ + b"\x03\x00\x00\x00" \ + b"\x48\x00\x00\x00" \ + b"\x04\x00\x00\x00" \ + b"\x00\x00" \ + b"\x00\x00" \ + b"\x02\x00\x00\x00" \ + b"\x01\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\xc0\xa8\x01\x64" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x02\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\xfe\x80\x12\xab\x00\x00\x00\x00" \ + b"\x00\x00\x00\x01\x00\x02\x00\x00" \ + b"\x01\x02\x03\x04" + data = actual.unpack(data) + assert len(actual) == 76 + assert data == b"" + assert actual['structure_size'].get_value() == 76 + assert actual['notification_type'].get_value() == 3 + assert actual['resource_name_offset'].get_value() == 72 + assert actual['resource_name_length'].get_value() == 4 + assert actual['flags'].get_value() == 0 + assert actual['target_type'].get_value() == 0 + assert actual['ip_addr_count'].get_value() == 2 + ip_addr = actual['ip_addr_move_list'].get_value() + assert isinstance(ip_addr, list) + assert len(ip_addr) == 2 + ip1 = ip_addr[0] + assert ip1['type'].get_value() == IpAddrType.MOVE_DST_IPADDR_V4 + assert ip1['reserved'].get_value() == 0 + assert ip1['ip_address'].get_value() == b"\xc0\xa8\x01\x64" + assert ip1['reserved2'].get_value() == b"\x00" * 12 + ip2 = ip_addr[1] + assert ip2['type'].get_value() == IpAddrType.MOVE_DST_IPADDR_V6 + assert ip2['reserved'].get_value() == 0 + assert ip2['ip_address'].get_value() == \ + b"\xfe\x80\x12\xab\x00\x00\x00\x00" \ + b"\x00\x00\x00\x01\x00\x02\x00\x00" + assert ip2['reserved2'].get_value() == b"" + assert actual['resource_name'].get_value() == b"\x01\x02\x03\x04" + + +class TestSMB2MoveDstIpAddrStructure(object): + + def test_create_message_v4(self): + message = SMB2MoveDstIpAddrStructure() + message['type'] = IpAddrType.MOVE_DST_IPADDR_V4 + message.set_ipaddress("192.168.1.100") + expected = b"\x01\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\xc0\xa8\x01\x64" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" + actual = message.pack() + assert len(message) == 24 + assert actual == expected + + def test_create_message_v6(self): + message = SMB2MoveDstIpAddrStructure() + message['type'] = IpAddrType.MOVE_DST_IPADDR_V6 + message.set_ipaddress("fe80:12ab:0000:0000:0000:0001:0002:0000") + expected = b"\x02\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\xfe\x80\x12\xab\x00\x00\x00\x00" \ + b"\x00\x00\x00\x01\x00\x02\x00\x00" + actual = message.pack() + assert len(message) == 24 + assert actual == expected + + def test_fail_invalid_ipv6_address(self): + message = SMB2MoveDstIpAddrStructure() + message['type'] = IpAddrType.MOVE_DST_IPADDR_V6 + with pytest.raises(ValueError) as exc: + message.set_ipaddress("abc") + assert str(exc.value) == "When setting an IPv6 address, it must be " \ + "in the full form without concatenation" + + def test_parse_message_v4(self): + actual = SMB2MoveDstIpAddrStructure() + data = b"\x01\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\xc0\xa8\x01\x64" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" + data = actual.unpack(data) + assert len(actual) == 24 + assert data == b"" + assert actual['type'].get_value() == IpAddrType.MOVE_DST_IPADDR_V4 + assert actual['reserved'].get_value() == 0 + assert actual['ip_address'].get_value() == b"\xc0\xa8\x01\x64" + assert actual['reserved2'].get_value() == b"\x00" * 12 + assert actual.get_ipaddress() == "192.168.1.100" + + def test_parse_message_v6(self): + actual = SMB2MoveDstIpAddrStructure() + data = b"\x02\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\xfe\x80\x12\xab\x00\x00\x00\x00" \ + b"\x00\x00\x00\x01\x00\x02\x00\x00" + data = actual.unpack(data) + assert len(actual) == 24 + assert data == b"" + assert actual['type'].get_value() == IpAddrType.MOVE_DST_IPADDR_V6 + assert actual['reserved'].get_value() == 0 + assert actual['ip_address'].get_value() == \ + b"\xfe\x80\x12\xab\x00\x00\x00\x00" \ + b"\x00\x00\x00\x01\x00\x02\x00\x00" + assert actual['reserved2'].get_value() == b"" + assert actual.get_ipaddress() == \ + "fe80:12ab:0000:0000:0000:0001:0002:0000" diff --git a/tests/test_ioctl.py b/tests/test_ioctl.py new file mode 100644 index 00000000..0d71d175 --- /dev/null +++ b/tests/test_ioctl.py @@ -0,0 +1,696 @@ +import uuid + +import pytest + +from smbprotocol.connection import Dialects +from smbprotocol.ioctl import CtlCode, HashRetrievalType, HashVersion, \ + IfCapability, IOCTLFlags, SMB2IOCTLRequest, SMB2IOCTLResponse, \ + SMB2NetworkInterfaceInfo, SMB2SrvCopyChunk, SMB2SrvCopyChunkCopy, \ + SMB2SrvCopyChunkResponse, SMB2SrvNetworkResiliencyRequest, \ + SMB2SrvReadHashRequest, SMB2SrvRequestResumeKey, SMB2SrvSnapshotArray, \ + SMB2ValidateNegotiateInfoRequest, SMB2ValidateNegotiateInfoResponse, \ + SockAddrFamily, SockAddrIn, SockAddrIn6, SockAddrStorage + + +class TestSMB2IOCTLRequest(object): + + def test_create_message(self): + message = SMB2IOCTLRequest() + message['ctl_code'] = CtlCode.FSCTL_VALIDATE_NEGOTIATE_INFO + message['file_id'] = b"\xff" * 16 + message['max_input_response'] = 12 + message['max_output_response'] = 12 + message['flags'] = IOCTLFlags.SMB2_0_IOCTL_IS_FSCTL + message['buffer'] = b"\x12\x13\x14\x15" + expected = b"\x39\x00" \ + b"\x00\x00" \ + b"\x04\x02\x14\x00" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\x78\x00\x00\x00" \ + b"\x04\x00\x00\x00" \ + b"\x0c\x00\x00\x00" \ + b"\x78\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x0c\x00\x00\x00" \ + b"\x01\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x12\x13\x14\x15" + actual = message.pack() + assert len(message) == 60 + assert actual == expected + + def test_create_message_no_buffer(self): + message = SMB2IOCTLRequest() + message['ctl_code'] = CtlCode.FSCTL_VALIDATE_NEGOTIATE_INFO + message['file_id'] = b"\xff" * 16 + message['flags'] = IOCTLFlags.SMB2_0_IOCTL_IS_FSCTL + expected = b"\x39\x00" \ + b"\x00\x00" \ + b"\x04\x02\x14\x00" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x01\x00\x00\x00" \ + b"\x00\x00\x00\x00" + actual = message.pack() + assert len(message) == 56 + assert actual == expected + + def test_parse_message(self): + actual = SMB2IOCTLRequest() + data = b"\x39\x00" \ + b"\x00\x00" \ + b"\x04\x02\x14\x00" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\x78\x00\x00\x00" \ + b"\x04\x00\x00\x00" \ + b"\x0c\x00\x00\x00" \ + b"\x78\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x0c\x00\x00\x00" \ + b"\x01\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x12\x13\x14\x15" + actual.unpack(data) + assert len(actual) == 60 + assert actual['structure_size'].get_value() == 57 + assert actual['reserved'].get_value() == 0 + assert actual['ctl_code'].get_value() == \ + CtlCode.FSCTL_VALIDATE_NEGOTIATE_INFO + assert actual['file_id'].pack() == b"\xff" * 16 + assert actual['input_offset'].get_value() == 120 + assert actual['input_count'].get_value() == 4 + assert actual['max_input_response'].get_value() == 12 + assert actual['output_offset'].get_value() == 120 + assert actual['output_count'].get_value() == 0 + assert actual['max_output_response'].get_value() == 12 + assert actual['flags'].get_value() == IOCTLFlags.SMB2_0_IOCTL_IS_FSCTL + assert actual['reserved2'].get_value() == 0 + assert actual['buffer'].get_value() == b"\x12\x13\x14\x15" + + +class TestSMB2SrvCopyChunkCopy(object): + + def test_create_message(self): + chunk1 = SMB2SrvCopyChunk() + chunk1['source_offset'] = 0 + chunk1['target_offset'] = 10 + chunk1['length'] = 10 + + chunk2 = SMB2SrvCopyChunk() + chunk2['source_offset'] = 10 + chunk2['target_offset'] = 20 + chunk2['length'] = 10 + + message = SMB2SrvCopyChunkCopy() + message['source_key'] = b"\x11" * 24 + message['chunks'] = [ + chunk1, + chunk2 + ] + + expected = b"\x11\x11\x11\x11\x11\x11\x11\x11" \ + b"\x11\x11\x11\x11\x11\x11\x11\x11" \ + b"\x11\x11\x11\x11\x11\x11\x11\x11" \ + b"\x02\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\x0a\x00\x00\x00\x00\x00\x00\x00" \ + b"\x0a\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x0a\x00\x00\x00\x00\x00\x00\x00" \ + b"\x14\x00\x00\x00\x00\x00\x00\x00" \ + b"\x0a\x00\x00\x00" \ + b"\x00\x00\x00\x00" + actual = message.pack() + assert len(message) == 80 + assert actual == expected + + def test_parse_message(self): + actual = SMB2SrvCopyChunkCopy() + data = b"\x11\x11\x11\x11\x11\x11\x11\x11" \ + b"\x11\x11\x11\x11\x11\x11\x11\x11" \ + b"\x11\x11\x11\x11\x11\x11\x11\x11" \ + b"\x02\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\x0a\x00\x00\x00\x00\x00\x00\x00" \ + b"\x0a\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x0a\x00\x00\x00\x00\x00\x00\x00" \ + b"\x14\x00\x00\x00\x00\x00\x00\x00" \ + b"\x0a\x00\x00\x00" \ + b"\x00\x00\x00\x00" + actual.unpack(data) + assert len(actual) == 80 + assert actual['source_key'].get_value() == b"\x11" * 24 + assert actual['chunk_count'].get_value() == 2 + assert actual['reserved'].get_value() == 0 + assert len(actual['chunks'].get_value()) == 2 + chunk1 = actual['chunks'][0] + assert chunk1['source_offset'].get_value() == 0 + assert chunk1['target_offset'].get_value() == 10 + assert chunk1['length'].get_value() == 10 + assert chunk1['reserved'].get_value() == 0 + chunk2 = actual['chunks'][1] + assert chunk2['source_offset'].get_value() == 10 + assert chunk2['target_offset'].get_value() == 20 + assert chunk2['length'].get_value() == 10 + assert chunk2['reserved'].get_value() == 0 + + +class TestSMB2SrvCopyChunk(object): + + def test_create_message(self): + message = SMB2SrvCopyChunk() + message['source_offset'] = 1234 + message['target_offset'] = 5678 + message['length'] = 10 + expected = b"\xd2\x04\x00\x00\x00\x00\x00\x00" \ + b"\x2e\x16\x00\x00\x00\x00\x00\x00" \ + b"\x0a\x00\x00\x00" \ + b"\x00\x00\x00\x00" + actual = message.pack() + assert len(message) == 24 + assert actual == expected + + def test_parse_message(self): + actual = SMB2SrvCopyChunk() + data = b"\xd2\x04\x00\x00\x00\x00\x00\x00" \ + b"\x2e\x16\x00\x00\x00\x00\x00\x00" \ + b"\x0a\x00\x00\x00" \ + b"\x00\x00\x00\x00" + actual.unpack(data) + assert len(actual) == 24 + assert actual['source_offset'].get_value() == 1234 + assert actual['target_offset'].get_value() == 5678 + assert actual['length'].get_value() == 10 + assert actual['reserved'].get_value() == 0 + + +class TestSMB2SrcReadHashRequest(object): + + def test_create_message(self): + message = SMB2SrvReadHashRequest() + message['hash_version'] = HashVersion.SRV_HASH_VER_2 + message['hash_retrieval_type'] = \ + HashRetrievalType.SRV_HASH_RETRIEVE_FILE_BASED + message['length'] = 10 + message['offset'] = 10 + expected = b"\x01\x00\x00\x00" \ + b"\x02\x00\x00\x00" \ + b"\x02\x00\x00\x00" \ + b"\x0a\x00\x00\x00" \ + b"\x0a\x00\x00\x00\x00\x00\x00\x00" + actual = message.pack() + assert len(message) == 24 + assert actual == expected + + def test_parse_message(self): + actual = SMB2SrvReadHashRequest() + data = b"\x01\x00\x00\x00" \ + b"\x02\x00\x00\x00" \ + b"\x02\x00\x00\x00" \ + b"\x0a\x00\x00\x00" \ + b"\x0a\x00\x00\x00\x00\x00\x00\x00" + actual.unpack(data) + assert len(actual) == 24 + assert actual['hash_type'].get_value() == 1 + assert actual['hash_version'].get_value() == HashVersion.SRV_HASH_VER_2 + assert actual['hash_retrieval_type'].get_value() == \ + HashRetrievalType.SRV_HASH_RETRIEVE_FILE_BASED + assert actual['length'].get_value() == 10 + assert actual['offset'].get_value() == 10 + + +class TestSMB2SrvNetworkResiliencyRequest(object): + + def test_create_message(self): + message = SMB2SrvNetworkResiliencyRequest() + message['timeout'] = 100 + expected = b"\x64\x00\x00\x00" \ + b"\x00\x00\x00\x00" + actual = message.pack() + assert len(message) == 8 + assert actual == expected + + def test_parse_message(self): + actual = SMB2SrvNetworkResiliencyRequest() + data = b"\x64\x00\x00\x00" \ + b"\x00\x00\x00\x00" + actual.unpack(data) + assert len(actual) == 8 + assert actual['timeout'].get_value() == 100 + assert actual['reserved'].get_value() == 0 + + +class TestSMB2ValidateNegotiateInfoRequest(object): + + def test_create_message(self): + message = SMB2ValidateNegotiateInfoRequest() + message['capabilities'] = 8 + message['guid'] = b"\x11" * 16 + message['security_mode'] = 1 + message['dialect_count'] = 2 + message['dialects'] = [Dialects.SMB_2_0_2, Dialects.SMB_2_1_0] + expected = b"\x08\x00\x00\x00" \ + b"\x11\x11\x11\x11\x11\x11\x11\x11" \ + b"\x11\x11\x11\x11\x11\x11\x11\x11" \ + b"\x01\x00" \ + b"\x02\x00" \ + b"\x02\x02\x10\x02" + actual = message.pack() + assert len(message) == 28 + assert actual == expected + + def test_parse_message(self): + actual = SMB2ValidateNegotiateInfoRequest() + data = b"\x08\x00\x00\x00" \ + b"\x11\x11\x11\x11\x11\x11\x11\x11" \ + b"\x11\x11\x11\x11\x11\x11\x11\x11" \ + b"\x01\x00" \ + b"\x02\x00" \ + b"\x02\x02\x10\x02" + actual.unpack(data) + assert len(actual) == 28 + assert actual['capabilities'].get_value() == 8 + assert actual['guid'].get_value() == uuid.UUID(bytes=b"\x11" * 16) + assert actual['security_mode'].get_value() == 1 + assert actual['dialect_count'].get_value() == 2 + assert actual['dialects'][0] == 514 + assert actual['dialects'][1] == 528 + assert len(actual['dialects'].get_value()) == 2 + + +class TestSMB2IOCTLResponse(object): + + def test_create_message(self): + message = SMB2IOCTLResponse() + message['ctl_code'] = CtlCode.FSCTL_VALIDATE_NEGOTIATE_INFO + message['file_id'] = b"\xff" * 16 + message['input_offset'] = 0 + message['input_count'] = 0 + message['output_offset'] = 112 + message['output_count'] = 4 + message['flags'] = IOCTLFlags.SMB2_0_IOCTL_IS_FSCTL + message['buffer'] = b"\x20\x21\x22\x23" + expected = b"\x31\x00\x00\x00" \ + b"\x04\x02\x14\x00" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x70\x00\x00\x00" \ + b"\x04\x00\x00\x00" \ + b"\x01\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x20\x21\x22\x23" + actual = message.pack() + assert len(message) == 52 + assert actual == expected + + def test_parse_message(self): + actual = SMB2IOCTLResponse() + data = b"\x31\x00\x00\x00" \ + b"\x04\x02\x14\x00" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x70\x00\x00\x00" \ + b"\x04\x00\x00\x00" \ + b"\x01\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x20\x21\x22\x23" + actual.unpack(data) + assert len(actual) == 52 + assert actual['structure_size'].get_value() == 49 + assert actual['reserved'].get_value() == 0 + assert actual['ctl_code'].get_value() == \ + CtlCode.FSCTL_VALIDATE_NEGOTIATE_INFO + assert actual['file_id'].pack() == b"\xff" * 16 + assert actual['input_offset'].get_value() == 0 + assert actual['input_count'].get_value() == 0 + assert actual['output_offset'].get_value() == 112 + assert actual['output_count'].get_value() == 4 + assert actual['flags'].get_value() == IOCTLFlags.SMB2_0_IOCTL_IS_FSCTL + assert actual['reserved2'].get_value() == 0 + assert actual['buffer'].get_value() == b"\x20\x21\x22\x23" + + +class TestSMB2SrvCopyChunkResponse(object): + + def test_create_message(self): + message = SMB2SrvCopyChunkResponse() + message['chunks_written'] = 2 + message['chunk_bytes_written'] = 10 + message['total_bytes_written'] = 10 + expected = b"\x02\x00\x00\x00" \ + b"\x0a\x00\x00\x00" \ + b"\x0a\x00\x00\x00" + actual = message.pack() + assert len(message) == 12 + assert actual == expected + + def test_parse_message(self): + actual = SMB2SrvCopyChunkResponse() + data = b"\x02\x00\x00\x00" \ + b"\x0a\x00\x00\x00" \ + b"\x0a\x00\x00\x00" + actual.unpack(data) + assert len(actual) == 12 + assert actual['chunks_written'].get_value() == 2 + assert actual['chunk_bytes_written'].get_value() == 10 + assert actual['total_bytes_written'].get_value() == 10 + + +class TestSMB2SrvSnapshotArray(object): + + def test_create_message(self): + message = SMB2SrvSnapshotArray() + message['snapshot_array_size'] = 2 + message['snapshots'] = b"\x00\x00\x00\x00" + expected = b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x02\x00\x00\x00" \ + b"\x00\x00\x00\x00" + actual = message.pack() + assert len(message) == 16 + assert actual == expected + + def test_parse_message(self): + actual = SMB2SrvSnapshotArray() + data = b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x02\x00\x00\x00" \ + b"\x00\x00\x00\x00" + actual.unpack(data) + assert len(actual) == 16 + assert actual['number_of_snapshots'].get_value() == 0 + assert actual['number_of_snapshots_returned'].get_value() == 0 + assert actual['snapshot_array_size'].get_value() == 2 + assert actual['snapshots'].get_value() == b"\x00\x00\x00\x00" + + +class TestSMB2SrvRequestResumeKey(object): + + def test_create_message(self): + message = SMB2SrvRequestResumeKey() + message['resume_key'] = b"\xff" * 24 + expected = b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\x00\x00\x00\x00" + actual = message.pack() + assert len(message) == 28 + assert actual == expected + + def test_parse_message(self): + actual = SMB2SrvRequestResumeKey() + data = b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\x00\x00\x00\x00" + actual.unpack(data) + assert len(actual) == 28 + assert actual['resume_key'].get_value() == b"\xff" * 24 + assert actual['context_length'].get_value() == 0 + + +class TestSMB2NetworkInterfaceInfo(object): + + def test_create_message(self): + addr1 = SockAddrIn() + addr1.set_ipaddress("10.0.2.15") + sock_addr1 = SockAddrStorage() + sock_addr1['family'] = SockAddrFamily.INTER_NETWORK + sock_addr1['buffer'] = addr1 + msg1 = SMB2NetworkInterfaceInfo() + msg1['if_index'] = 2 + msg1['link_speed'] = 1000000000 + msg1['sock_addr_storage'] = sock_addr1 + + addr2 = SockAddrIn6() + addr2.set_ipaddress("fe80:0000:0000:0000:894a:2dbc:1d9c:2da1") + sock_addr2 = SockAddrStorage() + sock_addr2['family'] = SockAddrFamily.INTER_NETWORK_V6 + sock_addr2['buffer'] = addr2 + msg2 = SMB2NetworkInterfaceInfo() + msg2['if_index'] = 4 + msg2['capability'].set_flag(IfCapability.RSS_CAPABLE) + msg2['link_speed'] = 1000000 + msg2['sock_addr_storage'] = sock_addr2 + + expected = b"\x98\x00\x00\x00" \ + b"\x02\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\xca\x9a\x3b\x00\x00\x00\x00" \ + b"\x02\x00" \ + b"\x00\x00" \ + b"\x0a\x00\x02\x0f" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" + expected += b"\x00" * 112 + expected += b"\x00\x00\x00\x00" \ + b"\x04\x00\x00\x00" \ + b"\x01\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x40\x42\x0f\x00\x00\x00\x00\x00" \ + b"\x17\x00" \ + b"\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\xfe\x80\x00\x00\x00\x00\x00\x00" \ + b"\x89\x4a\x2d\xbc\x1d\x9c\x2d\xa1" \ + b"\x00\x00\x00\x00" + expected += b"\x00" * 100 + actual = SMB2NetworkInterfaceInfo.pack_multiple([msg1, msg2]) + assert len(msg1) == 152 + assert len(msg2) == 152 + assert actual == expected + + def test_parse_message(self): + data = b"\x98\x00\x00\x00" \ + b"\x02\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\xca\x9a\x3b\x00\x00\x00\x00" \ + b"\x02\x00" \ + b"\x00\x00" \ + b"\x0a\x00\x02\x0f" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" + data += b"\x00" * 112 + data += b"\x00\x00\x00\x00" \ + b"\x04\x00\x00\x00" \ + b"\x01\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x40\x42\x0f\x00\x00\x00\x00\x00" \ + b"\x17\x00" \ + b"\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\xfe\x80\x00\x00\x00\x00\x00\x00" \ + b"\x89\x4a\x2d\xbc\x1d\x9c\x2d\xa1" \ + b"\x00\x00\x00\x00" + data += b"\x00" * 100 + actual = SMB2NetworkInterfaceInfo.unpack_multiple(data) + assert len(actual) == 2 + assert len(actual[0]) == 152 + assert len(actual[1]) == 152 + + assert actual[0]['next'].get_value() == 152 + assert actual[0]['if_index'].get_value() == 2 + assert actual[0]['capability'].get_value() == 0 + assert actual[0]['reserved'].get_value() == 0 + assert actual[0]['link_speed'].get_value() == 1000000000 + actual_sock1 = actual[0]['sock_addr_storage'].get_value() + assert actual_sock1['family'].get_value() == \ + SockAddrFamily.INTER_NETWORK + + assert actual[1]['next'].get_value() == 0 + assert actual[1]['if_index'].get_value() == 4 + assert actual[1]['capability'].get_value() == IfCapability.RSS_CAPABLE + assert actual[1]['reserved'].get_value() == 0 + assert actual[1]['link_speed'].get_value() == 1000000 + actual_sock2 = actual[1]['sock_addr_storage'].get_value() + assert actual_sock2['family'].get_value() == \ + SockAddrFamily.INTER_NETWORK_V6 + + +class TestSockAddrStorage(object): + + def test_create_message_ipv4(self): + message = SockAddrStorage() + message['family'] = SockAddrFamily.INTER_NETWORK + sock_addr = SockAddrIn() + sock_addr.set_ipaddress("10.0.2.15") + message['buffer'] = sock_addr + expected = b"\x02\x00" \ + b"\x00\x00" \ + b"\x0a\x00\x02\x0f" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" + expected += b"\x00" * 112 + actual = message.pack() + assert len(message) == 128 + assert actual == expected + + def test_create_message_ipv6(self): + message = SockAddrStorage() + message['family'] = SockAddrFamily.INTER_NETWORK_V6 + sock_addr = SockAddrIn6() + sock_addr.set_ipaddress("fe80:0000:0000:0000:894a:2dbc:1d9c:2da1") + message['buffer'] = sock_addr + expected = b"\x17\x00" \ + b"\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\xfe\x80\x00\x00\x00\x00\x00\x00" \ + b"\x89\x4a\x2d\xbc\x1d\x9c\x2d\xa1" \ + b"\x00\x00\x00\x00" + expected += b"\x00" * 100 + actual = message.pack() + assert len(message) == 128 + assert actual == expected + + def test_parse_message_ipv4(self): + actual = SockAddrStorage() + data = b"\x02\x00" \ + b"\x00\x00" \ + b"\x0a\x00\x02\x0f" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" + data += b"\x00" * 112 + actual.unpack(data) + assert len(actual) == 128 + assert actual['family'].get_value() == SockAddrFamily.INTER_NETWORK + sock_addr = actual['buffer'].get_value() + assert isinstance(sock_addr, SockAddrIn) + assert sock_addr.get_ipaddress() == \ + "10.0.2.15" + + def test_parse_message_ipv6(self): + actual = SockAddrStorage() + data = b"\x17\x00" \ + b"\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\xfe\x80\x00\x00\x00\x00\x00\x00" \ + b"\x89\x4a\x2d\xbc\x1d\x9c\x2d\xa1" \ + b"\x00\x00\x00\x00" + data += b"\x00" * 100 + actual.unpack(data) + assert len(actual) == 128 + assert actual['family'].get_value() == SockAddrFamily.INTER_NETWORK_V6 + sock_addr = actual['buffer'].get_value() + assert isinstance(sock_addr, SockAddrIn6) + assert sock_addr.get_ipaddress() == \ + "fe80:0000:0000:0000:894a:2dbc:1d9c:2da1" + + +class TestSockAddrIn(object): + + def test_create_message(self): + message = SockAddrIn() + message.set_ipaddress("10.0.2.15") + expected = b"\x00\x00" \ + b"\x0a\x00\x02\x0f" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" + actual = message.pack() + assert len(message) == 14 + assert actual == expected + + def test_create_message_subnet(self): + message = SockAddrIn() + message.set_ipaddress("255.255.255.255") + expected = b"\x00\x00" \ + b"\xff\xff\xff\xff" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" + actual = message.pack() + assert len(message) == 14 + assert actual == expected + + def test_parse_message(self): + actual = SockAddrIn() + data = b"\x00\x00" \ + b"\x0a\x00\x02\x0f" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" + actual.unpack(data) + assert len(actual) == 14 + assert actual['port'].get_value() == 0 + assert actual['ipv4_address'].get_value() == b"\x0a\x00\x02\x0f" + assert actual['reserved'].get_value() == 0 + assert actual.get_ipaddress() == "10.0.2.15" + + +class TestSockAddrIn6(object): + + def test_create_message(self): + message = SockAddrIn6() + message.set_ipaddress("fe80:0000:0000:0000:894a:2dbc:1d9c:2da1") + expected = b"\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\xfe\x80\x00\x00\x00\x00\x00\x00" \ + b"\x89\x4a\x2d\xbc\x1d\x9c\x2d\xa1" \ + b"\x00\x00\x00\x00" + actual = message.pack() + assert len(message) == 26 + assert actual == expected + + def test_set_ipaddress_invalid_format(self): + message = SockAddrIn6() + with pytest.raises(ValueError) as exc: + message.set_ipaddress("fe80::894a:2dbc:1d9c:2da1") + assert str(exc.value) == "When setting an IPv6 address, it must be " \ + "in the full form without concatenation" + + def test_parse_message(self): + actual = SockAddrIn6() + data = b"\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\xfe\x80\x00\x00\x00\x00\x00\x00" \ + b"\x89\x4a\x2d\xbc\x1d\x9c\x2d\xa1" \ + b"\x00\x00\x00\x00" + actual.unpack(data) + assert len(actual) == 26 + assert actual['port'].get_value() == 0 + assert actual['flow_info'].get_value() == 0 + assert actual['ipv6_address'].get_value() == \ + b"\xfe\x80\x00\x00\x00\x00\x00\x00" \ + b"\x89\x4a\x2d\xbc\x1d\x9c\x2d\xa1" + assert actual['scope_id'].get_value() == 0 + assert actual.get_ipaddress() == \ + "fe80:0000:0000:0000:894a:2dbc:1d9c:2da1" + + +class TestSMB2ValidateNegotiateInfoResponse(object): + + def test_create_message(self): + message = SMB2ValidateNegotiateInfoResponse() + message['capabilities'] = 8 + message['guid'] = b"\xff" * 16 + message['security_mode'] = 0 + message['dialect'] = Dialects.SMB_3_0_2 + expected = b"\x08\x00\x00\x00" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\x00\x00" \ + b"\x02\x03" + actual = message.pack() + assert len(message) == 24 + assert actual == expected + + def test_parse_message(self): + actual = SMB2ValidateNegotiateInfoResponse() + data = b"\x08\x00\x00\x00" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\x00\x00" \ + b"\x02\x03" + actual.unpack(data) + assert len(actual) == 24 + assert actual['capabilities'].get_value() == 8 + assert actual['guid'].get_value() == uuid.UUID(bytes=b"\xff" * 16) + assert actual['security_mode'].get_value() == 0 + assert actual['dialect'].get_value() == Dialects.SMB_3_0_2 diff --git a/tests/test_open.py b/tests/test_open.py new file mode 100644 index 00000000..4b7bc475 --- /dev/null +++ b/tests/test_open.py @@ -0,0 +1,1887 @@ +import os +import uuid + +from datetime import datetime + +import pytest + +from smbprotocol.connection import Connection, Dialects +from smbprotocol.exceptions import SMBException +from smbprotocol.session import Session +from smbprotocol.tree import TreeConnect +from smbprotocol.open import CloseFlags, CreateAction, CreateDisposition, \ + CreateOptions, DirectoryAccessMask, FileAttributes, FileInformationClass, \ + FileFlags, FilePipePrinterAccessMask, ImpersonationLevel, \ + ReadWriteChannel, ShareAccess, SMB2CloseRequest, SMB2CloseResponse, \ + SMB2CreateRequest, SMB2CreateResponse, SMB2FlushRequest, \ + SMB2FlushResponse, SMB2QueryDirectoryRequest, SMB2QueryDirectoryResponse, \ + SMB2ReadRequest, SMB2ReadResponse, SMB2WriteRequest, SMB2WriteResponse, \ + Open +from smbprotocol.query_info import FileNamesInformation +from smbprotocol.create_contexts import CreateContextName, \ + SMB2CreateAllocationSize, SMB2CreateContextRequest, \ + SMB2CreateQueryMaximalAccessRequest, \ + SMB2CreateQueryMaximalAccessResponse, SMB2CreateQueryOnDiskIDResponse, \ + SMB2CreateTimewarpToken +from smbprotocol.exceptions import SMBUnsupportedFeature + +from .utils import smb_real + + +class TestSMB2CreateRequest(object): + + def test_create_message(self): + timewarp_token = SMB2CreateTimewarpToken() + timewarp_token['timestamp'] = datetime.utcfromtimestamp(0) + timewarp_context = SMB2CreateContextRequest() + timewarp_context['buffer_name'] = \ + CreateContextName.SMB2_CREATE_TIMEWARP_TOKEN + timewarp_context['buffer_data'] = timewarp_token + + message = SMB2CreateRequest() + message['impersonation_level'] = ImpersonationLevel.Impersonation + message['desired_access'] = FilePipePrinterAccessMask.GENERIC_READ + message['file_attributes'] = FileAttributes.FILE_ATTRIBUTE_NORMAL + message['share_access'] = ShareAccess.FILE_SHARE_READ + message['create_disposition'] = CreateDisposition.FILE_OPEN + message['create_options'] = CreateOptions.FILE_NON_DIRECTORY_FILE + message['buffer_path'] = r"\\server\share".encode("utf-16-le") + message['buffer_contexts'] = [timewarp_context] + expected = b"\x39\x00" \ + b"\x00" \ + b"\x00" \ + b"\x02\x00\x00\x00" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x00\x80" \ + b"\x80\x00\x00\x00" \ + b"\x01\x00\x00\x00" \ + b"\x01\x00\x00\x00" \ + b"\x40\x00\x00\x00" \ + b"\x78\x00" \ + b"\x1c\x00" \ + b"\x98\x00\x00\x00" \ + b"\x20\x00\x00\x00" \ + b"\x5c\x00\x5c\x00\x73\x00\x65\x00" \ + b"\x72\x00\x76\x00\x65\x00\x72\x00" \ + b"\x5C\x00\x73\x00\x68\x00\x61\x00" \ + b"\x72\x00\x65\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x10\x00" \ + b"\x04\x00" \ + b"\x00\x00" \ + b"\x18\x00" \ + b"\x08\x00\x00\x00" \ + b"\x54\x57\x72\x70" \ + b"\x00\x00\x00\x00" \ + b"\x00\x80\x3e\xd5\xde\xb1\x9d\x01" + actual = message.pack() + assert len(message) == 120 + assert actual == expected + + def test_create_message_no_contexts(self): + message = SMB2CreateRequest() + message['impersonation_level'] = ImpersonationLevel.Impersonation + message['desired_access'] = FilePipePrinterAccessMask.GENERIC_READ + message['file_attributes'] = FileAttributes.FILE_ATTRIBUTE_NORMAL + message['share_access'] = ShareAccess.FILE_SHARE_READ + message['create_disposition'] = CreateDisposition.FILE_OPEN + message['create_options'] = CreateOptions.FILE_NON_DIRECTORY_FILE + message['buffer_path'] = r"\\server\share".encode("utf-16-le") + expected = b"\x39\x00" \ + b"\x00" \ + b"\x00" \ + b"\x02\x00\x00\x00" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x00\x80" \ + b"\x80\x00\x00\x00" \ + b"\x01\x00\x00\x00" \ + b"\x01\x00\x00\x00" \ + b"\x40\x00\x00\x00" \ + b"\x78\x00" \ + b"\x1c\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x5c\x00\x5c\x00\x73\x00\x65\x00" \ + b"\x72\x00\x76\x00\x65\x00\x72\x00" \ + b"\x5C\x00\x73\x00\x68\x00\x61\x00" \ + b"\x72\x00\x65\x00" + actual = message.pack() + assert len(message) == 84 + assert actual == expected + + def test_parse_message(self): + actual = SMB2CreateRequest() + data = b"\x39\x00" \ + b"\x00" \ + b"\x00" \ + b"\x02\x00\x00\x00" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x00\x80" \ + b"\x80\x00\x00\x00" \ + b"\x01\x00\x00\x00" \ + b"\x01\x00\x00\x00" \ + b"\x40\x00\x00\x00" \ + b"\x78\x00" \ + b"\x1c\x00" \ + b"\x98\x00\x00\x00" \ + b"\x20\x00\x00\x00" \ + b"\x5c\x00\x5c\x00\x73\x00\x65\x00" \ + b"\x72\x00\x76\x00\x65\x00\x72\x00" \ + b"\x5C\x00\x73\x00\x68\x00\x61\x00" \ + b"\x72\x00\x65\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x10\x00" \ + b"\x04\x00" \ + b"\x00\x00" \ + b"\x18\x00" \ + b"\x08\x00\x00\x00" \ + b"\x54\x57\x72\x70" \ + b"\x00\x00\x00\x00" \ + b"\x00\x80\x3e\xd5\xde\xb1\x9d\x01" + data = actual.unpack(data) + assert len(actual) == 120 + assert data == b"" + assert actual['structure_size'].get_value() == 57 + assert actual['security_flags'].get_value() == 0 + assert actual['requested_oplock_level'].get_value() == 0 + assert actual['impersonation_level'].get_value() == \ + ImpersonationLevel.Impersonation + assert actual['smb_create_flags'].get_value() == 0 + assert actual['reserved'].get_value() == 0 + assert actual['desired_access'].get_value() == \ + FilePipePrinterAccessMask.GENERIC_READ + assert actual['file_attributes'].get_value() == \ + FileAttributes.FILE_ATTRIBUTE_NORMAL + assert actual['share_access'].get_value() == \ + ShareAccess.FILE_SHARE_READ + assert actual['create_disposition'].get_value() == \ + CreateDisposition.FILE_OPEN + assert actual['create_options'].get_value() == \ + CreateOptions.FILE_NON_DIRECTORY_FILE + assert actual['name_offset'].get_value() == 120 + assert actual['name_length'].get_value() == 28 + assert actual['create_contexts_offset'].get_value() == 152 + assert actual['create_contexts_length'].get_value() == 32 + assert actual['buffer_path'].get_value() == \ + r"\\server\share".encode("utf-16-le") + assert actual['padding'].get_value() == b"\x00\x00\x00\x00" + + contexts = actual['buffer_contexts'].get_value() + assert isinstance(contexts, list) + timewarp_context = contexts[0] + assert timewarp_context['next'].get_value() == 0 + assert timewarp_context['name_offset'].get_value() == 16 + assert timewarp_context['name_length'].get_value() == 4 + assert timewarp_context['reserved'].get_value() == 0 + assert timewarp_context['data_offset'].get_value() == 24 + assert timewarp_context['data_length'].get_value() == 8 + assert timewarp_context['buffer_name'].get_value() == \ + CreateContextName.SMB2_CREATE_TIMEWARP_TOKEN + assert timewarp_context['padding'].get_value() == b"\x00\x00\x00\x00" + assert timewarp_context['buffer_data'].get_value() == \ + b"\x00\x80\x3e\xd5\xde\xb1\x9d\x01" + assert timewarp_context['padding2'].get_value() == b"" + + def test_parse_message_no_contexts(self): + actual = SMB2CreateRequest() + data = b"\x39\x00" \ + b"\x00" \ + b"\x00" \ + b"\x02\x00\x00\x00" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x00\x80" \ + b"\x80\x00\x00\x00" \ + b"\x01\x00\x00\x00" \ + b"\x01\x00\x00\x00" \ + b"\x40\x00\x00\x00" \ + b"\x78\x00" \ + b"\x1c\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x5c\x00\x5c\x00\x73\x00\x65\x00" \ + b"\x72\x00\x76\x00\x65\x00\x72\x00" \ + b"\x5C\x00\x73\x00\x68\x00\x61\x00" \ + b"\x72\x00\x65\x00" \ + + data = actual.unpack(data) + assert len(actual) == 84 + assert data == b"" + assert actual['structure_size'].get_value() == 57 + assert actual['security_flags'].get_value() == 0 + assert actual['requested_oplock_level'].get_value() == 0 + assert actual['impersonation_level'].get_value() == \ + ImpersonationLevel.Impersonation + assert actual['smb_create_flags'].get_value() == 0 + assert actual['reserved'].get_value() == 0 + assert actual['desired_access'].get_value() == \ + FilePipePrinterAccessMask.GENERIC_READ + assert actual['file_attributes'].get_value() == \ + FileAttributes.FILE_ATTRIBUTE_NORMAL + assert actual['share_access'].get_value() == \ + ShareAccess.FILE_SHARE_READ + assert actual['create_disposition'].get_value() == \ + CreateDisposition.FILE_OPEN + assert actual['create_options'].get_value() == \ + CreateOptions.FILE_NON_DIRECTORY_FILE + assert actual['name_offset'].get_value() == 120 + assert actual['name_length'].get_value() == 28 + assert actual['create_contexts_offset'].get_value() == 0 + assert actual['create_contexts_length'].get_value() == 0 + assert actual['buffer_path'].get_value() == \ + r"\\server\share".encode("utf-16-le") + assert actual['padding'].get_value() == b"" + assert actual['buffer_contexts'].get_value() == [] + + +class TestSMB2CreateResponse(object): + + def test_create_message(self): + message = SMB2CreateResponse() + message['flag'] = FileFlags.SMB2_CREATE_FLAG_REPARSEPOINT + message['create_action'] = CreateAction.FILE_CREATED + message['creation_time'] = datetime.utcfromtimestamp(1024) + message['last_access_time'] = datetime.utcfromtimestamp(2048) + message['last_write_time'] = datetime.utcfromtimestamp(3072) + message['change_time'] = datetime.utcfromtimestamp(4096) + message['allocation_size'] = 10 + message['end_of_file'] = 20 + message['file_attributes'] = FileAttributes.FILE_ATTRIBUTE_ARCHIVE + message['file_id'] = b"\xff" * 16 + + timewarp_token = SMB2CreateTimewarpToken() + timewarp_token['timestamp'] = datetime.utcfromtimestamp(0) + timewarp_context = SMB2CreateContextRequest() + timewarp_context['buffer_name'] = \ + CreateContextName.SMB2_CREATE_TIMEWARP_TOKEN + timewarp_context['buffer_data'] = timewarp_token + message['buffer'] = [timewarp_context] + expected = b"\x59\x00" \ + b"\x00" \ + b"\x01" \ + b"\x02\x00\x00\x00" \ + b"\x00\x80\x98\x37\xe1\xb1\x9d\x01" \ + b"\x00\x80\xf2\x99\xe3\xb1\x9d\x01" \ + b"\x00\x80\x4c\xfc\xe5\xb1\x9d\x01" \ + b"\x00\x80\xa6\x5e\xe8\xb1\x9d\x01" \ + b"\x0a\x00\x00\x00\x00\x00\x00\x00" \ + b"\x14\x00\x00\x00\x00\x00\x00\x00" \ + b"\x20\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\x98\x00\x00\x00" \ + b"\x20\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x10\x00" \ + b"\x04\x00" \ + b"\x00\x00" \ + b"\x18\x00" \ + b"\x08\x00\x00\x00" \ + b"\x54\x57\x72\x70" \ + b"\x00\x00\x00\x00" \ + b"\x00\x80\x3e\xd5\xde\xb1\x9d\x01" + actual = message.pack() + assert len(message) == 120 + assert actual == expected + + def test_create_message_no_contexts(self): + message = SMB2CreateResponse() + message['flag'] = FileFlags.SMB2_CREATE_FLAG_REPARSEPOINT + message['create_action'] = CreateAction.FILE_CREATED + message['creation_time'] = datetime.utcfromtimestamp(1024) + message['last_access_time'] = datetime.utcfromtimestamp(2048) + message['last_write_time'] = datetime.utcfromtimestamp(3072) + message['change_time'] = datetime.utcfromtimestamp(4096) + message['allocation_size'] = 10 + message['end_of_file'] = 20 + message['file_attributes'] = FileAttributes.FILE_ATTRIBUTE_ARCHIVE + message['file_id'] = b"\xff" * 16 + expected = b"\x59\x00" \ + b"\x00" \ + b"\x01" \ + b"\x02\x00\x00\x00" \ + b"\x00\x80\x98\x37\xe1\xb1\x9d\x01" \ + b"\x00\x80\xf2\x99\xe3\xb1\x9d\x01" \ + b"\x00\x80\x4c\xfc\xe5\xb1\x9d\x01" \ + b"\x00\x80\xa6\x5e\xe8\xb1\x9d\x01" \ + b"\x0a\x00\x00\x00\x00\x00\x00\x00" \ + b"\x14\x00\x00\x00\x00\x00\x00\x00" \ + b"\x20\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" + actual = message.pack() + assert len(message) == 88 + assert actual == expected + + def test_parse_message(self): + actual = SMB2CreateResponse() + data = b"\x59\x00" \ + b"\x00" \ + b"\x01" \ + b"\x02\x00\x00\x00" \ + b"\x00\x80\x98\x37\xe1\xb1\x9d\x01" \ + b"\x00\x80\xf2\x99\xe3\xb1\x9d\x01" \ + b"\x00\x80\x4c\xfc\xe5\xb1\x9d\x01" \ + b"\x00\x80\xa6\x5e\xe8\xb1\x9d\x01" \ + b"\x0a\x00\x00\x00\x00\x00\x00\x00" \ + b"\x14\x00\x00\x00\x00\x00\x00\x00" \ + b"\x20\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\x98\x00\x00\x00" \ + b"\x20\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x10\x00" \ + b"\x04\x00" \ + b"\x00\x00" \ + b"\x18\x00" \ + b"\x08\x00\x00\x00" \ + b"\x54\x57\x72\x70" \ + b"\x00\x00\x00\x00" \ + b"\x00\x80\x3e\xd5\xde\xb1\x9d\x01" + data = actual.unpack(data) + assert len(actual) == 120 + assert data == b"" + assert actual['structure_size'].get_value() == 89 + assert actual['oplock_level'].get_value() == 0 + assert actual['flag'].get_value() == \ + FileFlags.SMB2_CREATE_FLAG_REPARSEPOINT + assert actual['create_action'].get_value() == CreateAction.FILE_CREATED + assert actual['creation_time'].get_value() == \ + datetime.utcfromtimestamp(1024) + assert actual['last_access_time'].get_value() == \ + datetime.utcfromtimestamp(2048) + assert actual['last_write_time'].get_value() == \ + datetime.utcfromtimestamp(3072) + assert actual['change_time'].get_value() == \ + datetime.utcfromtimestamp(4096) + assert actual['allocation_size'].get_value() == 10 + assert actual['end_of_file'].get_value() == 20 + assert actual['file_attributes'].get_value() == \ + FileAttributes.FILE_ATTRIBUTE_ARCHIVE + assert actual['reserved2'].get_value() == 0 + assert actual['file_id'].pack() == b"\xff" * 16 + assert actual['create_contexts_offset'].get_value() == 152 + assert actual['create_contexts_length'].get_value() == 32 + + contexts = actual['buffer'].get_value() + assert isinstance(contexts, list) + timewarp_context = contexts[0] + assert timewarp_context['next'].get_value() == 0 + assert timewarp_context['name_offset'].get_value() == 16 + assert timewarp_context['name_length'].get_value() == 4 + assert timewarp_context['reserved'].get_value() == 0 + assert timewarp_context['data_offset'].get_value() == 24 + assert timewarp_context['data_length'].get_value() == 8 + assert timewarp_context['buffer_name'].get_value() == \ + CreateContextName.SMB2_CREATE_TIMEWARP_TOKEN + assert timewarp_context['padding'].get_value() == b"\x00\x00\x00\x00" + assert timewarp_context['buffer_data'].get_value() == \ + b"\x00\x80\x3e\xd5\xde\xb1\x9d\x01" + assert timewarp_context['padding2'].get_value() == b"" + + def test_parse_message_no_contexts(self): + actual = SMB2CreateResponse() + data = b"\x59\x00" \ + b"\x00" \ + b"\x01" \ + b"\x02\x00\x00\x00" \ + b"\x00\x80\x98\x37\xe1\xb1\x9d\x01" \ + b"\x00\x80\xf2\x99\xe3\xb1\x9d\x01" \ + b"\x00\x80\x4c\xfc\xe5\xb1\x9d\x01" \ + b"\x00\x80\xa6\x5e\xe8\xb1\x9d\x01" \ + b"\x0a\x00\x00\x00\x00\x00\x00\x00" \ + b"\x14\x00\x00\x00\x00\x00\x00\x00" \ + b"\x20\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" + data = actual.unpack(data) + assert len(actual) == 88 + assert data == b"" + assert actual['structure_size'].get_value() == 89 + assert actual['oplock_level'].get_value() == 0 + assert actual['flag'].get_value() == \ + FileFlags.SMB2_CREATE_FLAG_REPARSEPOINT + assert actual['create_action'].get_value() == CreateAction.FILE_CREATED + assert actual['creation_time'].get_value() == \ + datetime.utcfromtimestamp(1024) + assert actual['last_access_time'].get_value() == \ + datetime.utcfromtimestamp(2048) + assert actual['last_write_time'].get_value() == \ + datetime.utcfromtimestamp(3072) + assert actual['change_time'].get_value() == \ + datetime.utcfromtimestamp(4096) + assert actual['allocation_size'].get_value() == 10 + assert actual['end_of_file'].get_value() == 20 + assert actual['file_attributes'].get_value() == \ + FileAttributes.FILE_ATTRIBUTE_ARCHIVE + assert actual['reserved2'].get_value() == 0 + assert actual['file_id'].pack() == b"\xff" * 16 + assert actual['create_contexts_offset'].get_value() == 0 + assert actual['create_contexts_length'].get_value() == 0 + assert actual['buffer'].get_value() == [] + + +class TestSMB2CloseRequest(object): + + def test_create_message(self): + message = SMB2CloseRequest() + message['flags'].set_flag(CloseFlags.SMB2_CLOSE_FLAG_POSTQUERY_ATTRIB) + message['file_id'] = b"\xff" * 16 + expected = b"\x18\x00" \ + b"\x01\x00" \ + b"\x00\x00\x00\x00" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" + actual = message.pack() + assert len(actual) == 24 + assert actual == expected + + def test_parse_message(self): + actual = SMB2CloseRequest() + data = b"\x18\x00" \ + b"\x01\x00" \ + b"\x00\x00\x00\x00" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" + actual.unpack(data) + assert len(actual) == 24 + assert actual['structure_size'].get_value() == 24 + assert actual['flags'].get_value() == \ + CloseFlags.SMB2_CLOSE_FLAG_POSTQUERY_ATTRIB + assert actual['reserved'].get_value() == 0 + assert actual['file_id'].get_value() == b"\xff" * 16 + + +class TestSMB2CloseResponse(object): + + def test_create_message(self): + message = SMB2CloseResponse() + message['creation_time'] = datetime.utcfromtimestamp(0) + message['last_access_time'] = datetime.utcfromtimestamp(0) + message['last_write_time'] = datetime.utcfromtimestamp(0) + message['change_time'] = datetime.utcfromtimestamp(0) + expected = b"\x3c\x00" \ + b"\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x80\x3E\xD5\xDE\xB1\x9D\x01" \ + b"\x00\x80\x3E\xD5\xDE\xB1\x9D\x01" \ + b"\x00\x80\x3E\xD5\xDE\xB1\x9D\x01" \ + b"\x00\x80\x3E\xD5\xDE\xB1\x9D\x01" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" + actual = message.pack() + assert len(actual) == 60 + assert actual == expected + + def test_parse_message(self): + actual = SMB2CloseResponse() + data = b"\x3c\x00" \ + b"\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x80\x3E\xD5\xDE\xB1\x9D\x01" \ + b"\x00\x80\x3E\xD5\xDE\xB1\x9D\x01" \ + b"\x00\x80\x3E\xD5\xDE\xB1\x9D\x01" \ + b"\x00\x80\x3E\xD5\xDE\xB1\x9D\x01" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" + actual.unpack(data) + assert len(actual) == 60 + assert actual['structure_size'].get_value() == 60 + assert actual['flags'].get_value() == 0 + assert actual['reserved'].get_value() == 0 + assert actual['creation_time'].get_value() == \ + datetime.utcfromtimestamp(0) + assert actual['last_access_time'].get_value() == \ + datetime.utcfromtimestamp(0) + assert actual['last_write_time'].get_value() == \ + datetime.utcfromtimestamp(0) + assert actual['change_time'].get_value() == \ + datetime.utcfromtimestamp(0) + assert actual['allocation_size'].get_value() == 0 + assert actual['end_of_file'].get_value() == 0 + assert actual['file_attributes'].get_value() == 0 + + +class TestSMB2FlushRequest(object): + + def test_create_message(self): + message = SMB2FlushRequest() + message['file_id'] = b"\xff" * 16 + expected = b"\x18\x00" \ + b"\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" + actual = message.pack() + assert len(message) == 24 + assert actual == expected + + def test_parse_message(self): + actual = SMB2FlushRequest() + data = b"\x18\x00" \ + b"\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" + actual.unpack(data) + assert len(actual) == 24 + assert actual['structure_size'].get_value() == 24 + assert actual['reserved1'].get_value() == 0 + assert actual['reserved2'].get_value() == 0 + assert actual['file_id'].pack() == b"\xff" * 16 + + +class TestSMB2FlushResponse(object): + + def test_create_message(self): + message = SMB2FlushResponse() + expected = b"\x04\x00" \ + b"\x00\x00" + actual = message.pack() + assert len(message) == 4 + assert actual == expected + + def test_parse_message(self): + actual = SMB2FlushResponse() + data = b"\x04\x00" \ + b"\x00\x00" + actual.unpack(data) + assert len(actual) == 4 + assert actual['structure_size'].get_value() == 4 + assert actual['reserved'].get_value() == 0 + + +class TestSMB2ReadRequest(object): + + def test_create_message(self): + message = SMB2ReadRequest() + message['padding'] = b"\x50" + message['length'] = 1024 + message['offset'] = 0 + message['file_id'] = b"\xff" * 16 + message['remaining_bytes'] = 0 + expected = b"\x31\x00" \ + b"\x50" \ + b"\x00" \ + b"\x00\x04\x00\x00" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00" \ + b"\x00\x00" \ + b"\x00" + actual = message.pack() + assert len(message) == 49 + assert actual == expected + + def test_create_message_channel_info(self): + message = SMB2ReadRequest() + message['padding'] = b"\x50" + message['length'] = 1024 + message['offset'] = 0 + message['file_id'] = b"\xff" * 16 + message['channel'].set_flag(ReadWriteChannel.SMB2_CHANNEL_RDMA_V1) + message['remaining_bytes'] = 0 + message['buffer'] = b"\x00" * 16 + expected = b"\x31\x00" \ + b"\x50" \ + b"\x00" \ + b"\x00\x04\x00\x00" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\x00\x00\x00\x00" \ + b"\x01\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x70\x00" \ + b"\x10\x00" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" + actual = message.pack() + assert len(message) == 64 + assert actual == expected + + def test_parse_message(self): + actual = SMB2ReadRequest() + data = b"\x31\x00" \ + b"\x50" \ + b"\x00" \ + b"\x00\x04\x00\x00" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00" \ + b"\x00\x00" \ + b"\x00" + actual.unpack(data) + assert len(actual) == 49 + assert actual['structure_size'].get_value() == 49 + assert actual['padding'].get_value() == 80 + assert actual['flags'].get_value() == 0 + assert actual['length'].get_value() == 1024 + assert actual['offset'].get_value() == 0 + assert actual['file_id'].pack() == b"\xff" * 16 + assert actual['minimum_count'].get_value() == 0 + assert actual['channel'].get_value() == 0 + assert actual['remaining_bytes'].get_value() == 0 + assert actual['read_channel_info_offset'].get_value() == 0 + assert actual['read_channel_info_length'].get_value() == 0 + assert actual['buffer'].get_value() == b"\x00" + + def test_parse_message_channel_info(self): + actual = SMB2ReadRequest() + data = b"\x31\x00" \ + b"\x50" \ + b"\x00" \ + b"\x00\x04\x00\x00" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\x00\x00\x00\x00" \ + b"\x01\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x70\x00" \ + b"\x10\x00" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" + actual.unpack(data) + assert len(actual) == 64 + assert actual['structure_size'].get_value() == 49 + assert actual['padding'].get_value() == 80 + assert actual['flags'].get_value() == 0 + assert actual['length'].get_value() == 1024 + assert actual['offset'].get_value() == 0 + assert actual['file_id'].pack() == b"\xff" * 16 + assert actual['minimum_count'].get_value() == 0 + assert actual['channel'].get_value() == \ + ReadWriteChannel.SMB2_CHANNEL_RDMA_V1 + assert actual['remaining_bytes'].get_value() == 0 + assert actual['read_channel_info_offset'].get_value() == 112 + assert actual['read_channel_info_length'].get_value() == 16 + assert actual['buffer'].get_value() == b"\x00" * 16 + + +class TestSMB2ReadResponse(object): + + def test_create_message(self): + message = SMB2ReadResponse() + message['data_offset'] = 80 + message['data_length'] = 4 + message['buffer'] = b"\x01\x02\x03\x04" + expected = b"\x11\x00" \ + b"\x50" \ + b"\x00" \ + b"\x04\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x01\x02\x03\x04" + actual = message.pack() + assert len(message) == 20 + assert actual == expected + + def test_parse_message(self): + actual = SMB2ReadResponse() + data = b"\x11\x00" \ + b"\x50" \ + b"\x00" \ + b"\x04\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x01\x02\x03\x04" + actual.unpack(data) + assert len(actual) == 20 + assert actual['structure_size'].get_value() == 17 + assert actual['data_offset'].get_value() == 80 + assert actual['reserved'].get_value() == 0 + assert actual['data_length'].get_value() == 4 + assert actual['data_remaining'].get_value() == 0 + assert actual['reserved2'].get_value() == 0 + assert actual['buffer'].get_value() == b"\x01\x02\x03\x04" + + +class TestSMB2WriteRequest(object): + + def test_create_message(self): + message = SMB2WriteRequest() + message['offset'] = 131072 + message['file_id'] = b"\xff" * 16 + message['channel'].set_flag(ReadWriteChannel.SMB2_CHANNEL_NONE) + message['remaining_bytes'] = 0 + message['buffer'] = b"\x01\x02\x03\x04" + expected = b"\x31\x00" \ + b"\x70\x00" \ + b"\x04\x00\x00\x00" \ + b"\x00\x00\x02\x00\x00\x00\x00\x00" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00" \ + b"\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x01\x02\x03\x04" + actual = message.pack() + assert len(message) == 52 + assert actual == expected + + def test_create_message_channel_info(self): + message = SMB2WriteRequest() + message['offset'] = 131072 + message['file_id'] = b"\xff" * 16 + message['channel'].set_flag(ReadWriteChannel.SMB2_CHANNEL_RDMA_V1) + message['remaining_bytes'] = 0 + message['buffer'] = b"\x01\x02\x03\x04" + message['buffer_channel_info'] = b"\x00" * 16 + expected = b"\x31\x00" \ + b"\x70\x00" \ + b"\x04\x00\x00\x00" \ + b"\x00\x00\x02\x00\x00\x00\x00\x00" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\x01\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x74\x00" \ + b"\x10\x00" \ + b"\x00\x00\x00\x00" \ + b"\x01\x02\x03\x04" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" + actual = message.pack() + assert len(message) == 68 + assert actual == expected + + def test_parse_message(self): + actual = SMB2WriteRequest() + data = b"\x31\x00" \ + b"\x70\x00" \ + b"\x04\x00\x00\x00" \ + b"\x00\x00\x02\x00\x00\x00\x00\x00" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00" \ + b"\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x01\x02\x03\x04" + actual.unpack(data) + assert len(actual) == 52 + assert actual['structure_size'].get_value() == 49 + assert actual['data_offset'].get_value() == 112 + assert actual['length'].get_value() == 4 + assert actual['offset'].get_value() == 131072 + assert actual['file_id'].pack() == b"\xff" * 16 + assert actual['channel'].get_value() == 0 + assert actual['remaining_bytes'].get_value() == 0 + assert actual['write_channel_info_offset'].get_value() == 0 + assert actual['write_channel_info_length'].get_value() == 0 + assert actual['flags'].get_value() == 0 + assert actual['buffer'].get_value() == b"\x01\x02\x03\x04" + assert actual['buffer_channel_info'].get_value() == b"" + + def test_parse_message_channel_info(self): + actual = SMB2WriteRequest() + data = b"\x31\x00" \ + b"\x70\x00" \ + b"\x04\x00\x00\x00" \ + b"\x00\x00\x02\x00\x00\x00\x00\x00" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\xff\xff\xff\xff\xff\xff\xff\xff" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x74\x00" \ + b"\x10\x00" \ + b"\x00\x00\x00\x00" \ + b"\x01\x02\x03\x04" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" + actual.unpack(data) + assert len(actual) == 68 + assert actual['structure_size'].get_value() == 49 + assert actual['data_offset'].get_value() == 112 + assert actual['length'].get_value() == 4 + assert actual['offset'].get_value() == 131072 + assert actual['file_id'].pack() == b"\xff" * 16 + assert actual['channel'].get_value() == 0 + assert actual['remaining_bytes'].get_value() == 0 + assert actual['write_channel_info_offset'].get_value() == 116 + assert actual['write_channel_info_length'].get_value() == 16 + assert actual['flags'].get_value() == 0 + assert actual['buffer'].get_value() == b"\x01\x02\x03\x04" + assert actual['buffer_channel_info'].get_value() == b"\x00" * 16 + + +class TestSMB2WriteResponse(object): + + def test_create_message(self): + message = SMB2WriteResponse() + message['count'] = 58040 + expected = b"\x11\x00" \ + b"\x00\x00" \ + b"\xb8\xe2\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00" \ + b"\x00\x00" + actual = message.pack() + assert len(message) == 16 + assert actual == expected + + def test_parse_message(self): + actual = SMB2WriteResponse() + data = b"\x11\x00" \ + b"\x00\x00" \ + b"\xb8\xe2\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00" \ + b"\x00\x00" + actual.unpack(data) + assert len(actual) == 16 + assert actual['structure_size'].get_value() == 17 + assert actual['reserved'].get_value() == 0 + assert actual['count'].get_value() == 58040 + assert actual['remaining'].get_value() == 0 + assert actual['write_channel_info_offset'].get_value() == 0 + assert actual['write_channel_info_length'].get_value() == 0 + + +class TestSMB2QueryDirectoryRequest(object): + + def test_create_message(self): + message = SMB2QueryDirectoryRequest() + message['file_information_class'] = \ + FileInformationClass.FILE_NAMES_INFORMATION + message['file_id'] = b"\xB6\x73\xE4\x65\x00\x00\x00\x00" \ + b"\x68\xBD\xA1\xCE\x00\x00\x00\x00" + message['output_buffer_length'] = 65536 + message['buffer'] = "*".encode('utf-16-le') + expected = b"\x21\x00" \ + b"\x0C" \ + b"\x00" \ + b"\x00\x00\x00\x00" \ + b"\xB6\x73\xE4\x65\x00\x00\x00\x00" \ + b"\x68\xBD\xA1\xCE\x00\x00\x00\x00" \ + b"\x60\x00" \ + b"\x02\x00" \ + b"\x00\x00\x01\x00" \ + b"\x2A\x00" + actual = message.pack() + assert len(message) == 34 + assert actual == expected + + def test_parse_message(self): + actual = SMB2QueryDirectoryRequest() + data = b"\x21\x00" \ + b"\x0C" \ + b"\x00" \ + b"\x00\x00\x00\x00" \ + b"\xB6\x73\xE4\x65\x00\x00\x00\x00" \ + b"\x68\xBD\xA1\xCE\x00\x00\x00\x00" \ + b"\x60\x00" \ + b"\x02\x00" \ + b"\x00\x00\x01\x00" \ + b"\x2A\x00" + data = actual.unpack(data) + assert len(actual) == 34 + assert data == b"" + assert actual['structure_size'].get_value() == 33 + assert actual['file_information_class'].get_value() == \ + FileInformationClass.FILE_NAMES_INFORMATION + assert actual['flags'].get_value() == 0 + assert actual['file_index'].get_value() == 0 + assert actual['file_id'].get_value() == \ + b"\xB6\x73\xE4\x65\x00\x00\x00\x00" \ + b"\x68\xBD\xA1\xCE\x00\x00\x00\x00" + assert actual['file_name_offset'].get_value() == 96 + assert actual['file_name_length'].get_value() == 2 + assert actual['output_buffer_length'].get_value() == 65536 + assert actual['buffer'].get_value().decode('utf-16-le') == "*" + + +class TestSMB2QueryDirectoryResponse(object): + + def test_create_message(self): + message = SMB2QueryDirectoryResponse() + message['buffer'] = b"\x10\x00\x00\x00\x00\x00\x00\x00" \ + b"\x02\x00\x00\x00\x2E\x00\x00\x00" + expected = b"\x09\x00" \ + b"\x48\x00" \ + b"\x10\x00\x00\x00" \ + b"\x10\x00\x00\x00\x00\x00\x00\x00" \ + b"\x02\x00\x00\x00\x2E\x00\x00\x00" + actual = message.pack() + assert len(message) == 24 + assert actual == expected + + def test_parse_message(self): + actual = SMB2QueryDirectoryResponse() + data = b"\x09\x00" \ + b"\x48\x00" \ + b"\x10\x00\x00\x00" \ + b"\x10\x00\x00\x00\x00\x00\x00\x00" \ + b"\x02\x00\x00\x00\x2E\x00\x00\x00" + data = actual.unpack(data) + assert len(actual) == 24 + assert data == b"" + assert actual['structure_size'].get_value() == 9 + assert actual['output_buffer_offset'].get_value() == 72 + assert actual['output_buffer_length'].get_value() == 16 + assert actual['buffer'].get_value() == \ + b"\x10\x00\x00\x00\x00\x00\x00\x00" \ + b"\x02\x00\x00\x00\x2E\x00\x00\x00" + + +class TestOpen(object): + + # basic file open tests for each dialect + def test_dialect_2_0_2(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3]) + connection.connect(Dialects.SMB_2_0_2) + session = Session(connection, smb_real[0], smb_real[1], False) + tree = TreeConnect(session, smb_real[4]) + open = Open(tree, "file.txt") + try: + session.connect() + tree.connect() + + out_cont = open.open(ImpersonationLevel.Impersonation, + FilePipePrinterAccessMask.MAXIMUM_ALLOWED, + FileAttributes.FILE_ATTRIBUTE_NORMAL, + 0, + CreateDisposition.FILE_OVERWRITE_IF, + CreateOptions.FILE_NON_DIRECTORY_FILE) + assert out_cont is None + assert open.allocation_size == 0 + assert isinstance(open.change_time, datetime) + assert open.create_disposition is None + assert open.create_options is None + assert isinstance(open.creation_time, datetime) + assert open.desired_access is None + assert not open.durable + assert open.durable_timeout is None + assert open.end_of_file == 0 + assert open.file_attributes == 32 + assert isinstance(open.file_id, bytes) + assert open.file_name == "file.txt" + assert open.is_persistent is None + assert isinstance(open.last_access_time, datetime) + assert open.last_disconnect_time == 0 + assert isinstance(open.last_write_time, datetime) + assert open.operation_buckets == [] + assert open.oplock_level == 0 + assert not open.resilient_handle + assert not open.resilient_timeout + assert open.share_mode is None + finally: + connection.disconnect(True) + + def test_dialect_2_1_0(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3]) + connection.connect(Dialects.SMB_2_1_0) + session = Session(connection, smb_real[0], smb_real[1], False) + tree = TreeConnect(session, smb_real[4]) + open = Open(tree, "file.txt") + try: + session.connect() + tree.connect() + + out_cont = open.open(ImpersonationLevel.Impersonation, + FilePipePrinterAccessMask.MAXIMUM_ALLOWED, + FileAttributes.FILE_ATTRIBUTE_NORMAL, + 0, + CreateDisposition.FILE_OVERWRITE_IF, + CreateOptions.FILE_NON_DIRECTORY_FILE) + assert out_cont is None + assert open.allocation_size == 0 + assert isinstance(open.change_time, datetime) + assert open.create_disposition is None + assert open.create_options is None + assert isinstance(open.creation_time, datetime) + assert open.desired_access is None + assert not open.durable + assert open.durable_timeout is None + assert open.end_of_file == 0 + assert open.file_attributes == 32 + assert isinstance(open.file_id, bytes) + assert open.file_name == "file.txt" + assert open.is_persistent is None + assert isinstance(open.last_access_time, datetime) + assert open.last_disconnect_time == 0 + assert isinstance(open.last_write_time, datetime) + assert open.operation_buckets == [] + assert open.oplock_level == 0 + assert not open.resilient_handle + assert not open.resilient_timeout + assert open.share_mode is None + finally: + connection.disconnect(True) + + def test_dialect_3_0_0(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3]) + connection.connect(Dialects.SMB_3_0_0) + session = Session(connection, smb_real[0], smb_real[1]) + tree = TreeConnect(session, smb_real[4]) + open = Open(tree, "file.txt") + try: + session.connect() + tree.connect() + + out_cont = open.open(ImpersonationLevel.Impersonation, + FilePipePrinterAccessMask.MAXIMUM_ALLOWED, + FileAttributes.FILE_ATTRIBUTE_NORMAL, + 0, + CreateDisposition.FILE_OVERWRITE_IF, + CreateOptions.FILE_NON_DIRECTORY_FILE) + assert out_cont is None + assert open.allocation_size == 0 + assert isinstance(open.change_time, datetime) + assert open.create_disposition is \ + CreateDisposition.FILE_OVERWRITE_IF + assert open.create_options is CreateOptions.FILE_NON_DIRECTORY_FILE + assert isinstance(open.creation_time, datetime) + assert open.desired_access is \ + FilePipePrinterAccessMask.MAXIMUM_ALLOWED + assert not open.durable + assert open.durable_timeout is None + assert open.end_of_file == 0 + assert open.file_attributes == 32 + assert isinstance(open.file_id, bytes) + assert open.file_name == "file.txt" + assert open.is_persistent is None + assert isinstance(open.last_access_time, datetime) + assert open.last_disconnect_time == 0 + assert isinstance(open.last_write_time, datetime) + assert open.operation_buckets == [] + assert open.oplock_level == 0 + assert not open.resilient_handle + assert not open.resilient_timeout + assert open.share_mode == 0 + finally: + connection.disconnect(True) + + def test_dialect_3_0_2(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3]) + connection.connect(Dialects.SMB_3_0_2) + session = Session(connection, smb_real[0], smb_real[1]) + tree = TreeConnect(session, smb_real[4]) + open = Open(tree, "file.txt") + try: + session.connect() + tree.connect() + + out_cont = open.open(ImpersonationLevel.Impersonation, + FilePipePrinterAccessMask.MAXIMUM_ALLOWED, + FileAttributes.FILE_ATTRIBUTE_NORMAL, + 0, + CreateDisposition.FILE_OVERWRITE_IF, + CreateOptions.FILE_NON_DIRECTORY_FILE) + assert out_cont is None + assert open.allocation_size == 0 + assert isinstance(open.change_time, datetime) + assert open.create_disposition is \ + CreateDisposition.FILE_OVERWRITE_IF + assert open.create_options is CreateOptions.FILE_NON_DIRECTORY_FILE + assert isinstance(open.creation_time, datetime) + assert open.desired_access is \ + FilePipePrinterAccessMask.MAXIMUM_ALLOWED + assert not open.durable + assert open.durable_timeout is None + assert open.end_of_file == 0 + assert open.file_attributes == 32 + assert isinstance(open.file_id, bytes) + assert open.file_name == "file.txt" + assert open.is_persistent is None + assert isinstance(open.last_access_time, datetime) + assert open.last_disconnect_time == 0 + assert isinstance(open.last_write_time, datetime) + assert open.operation_buckets == [] + assert open.oplock_level == 0 + assert not open.resilient_handle + assert not open.resilient_timeout + assert open.share_mode == 0 + finally: + connection.disconnect(True) + + def test_dialect_3_1_1(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3]) + connection.connect(Dialects.SMB_3_1_1) + session = Session(connection, smb_real[0], smb_real[1]) + tree = TreeConnect(session, smb_real[4]) + open = Open(tree, "file.txt") + try: + session.connect() + tree.connect() + + out_cont = open.open(ImpersonationLevel.Impersonation, + FilePipePrinterAccessMask.MAXIMUM_ALLOWED, + FileAttributes.FILE_ATTRIBUTE_NORMAL, + 0, + CreateDisposition.FILE_OVERWRITE_IF, + CreateOptions.FILE_NON_DIRECTORY_FILE) + assert out_cont is None + assert open.allocation_size == 0 + assert isinstance(open.change_time, datetime) + assert open.create_disposition is \ + CreateDisposition.FILE_OVERWRITE_IF + assert open.create_options is CreateOptions.FILE_NON_DIRECTORY_FILE + assert isinstance(open.creation_time, datetime) + assert open.desired_access is \ + FilePipePrinterAccessMask.MAXIMUM_ALLOWED + assert not open.durable + assert open.durable_timeout is None + assert open.end_of_file == 0 + assert open.file_attributes == 32 + assert isinstance(open.file_id, bytes) + assert open.file_name == "file.txt" + assert open.is_persistent is None + assert isinstance(open.last_access_time, datetime) + assert open.last_disconnect_time == 0 + assert isinstance(open.last_write_time, datetime) + assert open.operation_buckets == [] + assert open.oplock_level == 0 + assert not open.resilient_handle + assert not open.resilient_timeout + assert open.share_mode == 0 + finally: + connection.disconnect(True) + + # test more file operations here + def test_create_directory(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3]) + connection.connect(Dialects.SMB_3_0_0) + session = Session(connection, smb_real[0], smb_real[1]) + tree = TreeConnect(session, smb_real[5]) + open = Open(tree, "folder") + try: + session.connect() + tree.connect() + + out_cont = open.open(ImpersonationLevel.Impersonation, + DirectoryAccessMask.MAXIMUM_ALLOWED, + FileAttributes.FILE_ATTRIBUTE_DIRECTORY, + 0, + CreateDisposition.FILE_OPEN_IF, + CreateOptions.FILE_DIRECTORY_FILE) + assert out_cont is None + assert open.allocation_size == 0 + assert isinstance(open.change_time, datetime) + assert open.create_disposition is \ + CreateDisposition.FILE_OPEN_IF + assert open.create_options is CreateOptions.FILE_DIRECTORY_FILE + assert isinstance(open.creation_time, datetime) + assert open.desired_access is \ + DirectoryAccessMask.MAXIMUM_ALLOWED + assert not open.durable + assert open.durable_timeout is None + assert open.end_of_file == 0 + assert open.file_attributes == \ + FileAttributes.FILE_ATTRIBUTE_DIRECTORY + assert isinstance(open.file_id, bytes) + assert open.file_name == "folder" + assert open.is_persistent is None + assert isinstance(open.last_access_time, datetime) + assert open.last_disconnect_time == 0 + assert isinstance(open.last_write_time, datetime) + assert open.operation_buckets == [] + assert open.oplock_level == 0 + assert not open.resilient_handle + assert not open.resilient_timeout + assert open.share_mode == 0 + finally: + connection.disconnect(True) + + def test_create_file_create_contexts(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3]) + connection.connect(Dialects.SMB_3_0_0) + session = Session(connection, smb_real[0], smb_real[1]) + tree = TreeConnect(session, smb_real[5]) + open = Open(tree, "file-cont.txt") + try: + session.connect() + tree.connect() + + alloc_size = SMB2CreateAllocationSize() + alloc_size['allocation_size'] = 1024 + + alloc_size_context = SMB2CreateContextRequest() + alloc_size_context['buffer_name'] = \ + CreateContextName.SMB2_CREATE_ALLOCATION_SIZE + alloc_size_context['buffer_data'] = alloc_size + + query_disk = SMB2CreateContextRequest() + query_disk['buffer_name'] = \ + CreateContextName.SMB2_CREATE_QUERY_ON_DISK_ID + + max_req_data = SMB2CreateQueryMaximalAccessRequest() + max_req = SMB2CreateContextRequest() + max_req['buffer_name'] = \ + CreateContextName.SMB2_CREATE_QUERY_MAXIMAL_ACCESS_REQUEST + max_req['buffer_data'] = max_req_data + + create_contexts = [ + alloc_size_context, + query_disk, + max_req + ] + out_cont = open.open(ImpersonationLevel.Impersonation, + FilePipePrinterAccessMask.MAXIMUM_ALLOWED, + FileAttributes.FILE_ATTRIBUTE_NORMAL, + 0, + CreateDisposition.FILE_OVERWRITE_IF, + CreateOptions.FILE_NON_DIRECTORY_FILE, + create_contexts) + assert len(out_cont) == 2 + assert isinstance(out_cont[0], + SMB2CreateQueryMaximalAccessResponse) or \ + isinstance(out_cont[0], SMB2CreateQueryOnDiskIDResponse) + assert isinstance(out_cont[1], + SMB2CreateQueryMaximalAccessResponse) or \ + isinstance(out_cont[1], SMB2CreateQueryOnDiskIDResponse) + finally: + connection.disconnect(True) + + def test_create_read_write_from_file(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3]) + connection.connect(Dialects.SMB_3_0_2) + session = Session(connection, smb_real[0], smb_real[1]) + tree = TreeConnect(session, smb_real[4]) + open = Open(tree, "file-read-write.txt") + try: + session.connect() + tree.connect() + + open.open(ImpersonationLevel.Impersonation, + FilePipePrinterAccessMask.MAXIMUM_ALLOWED, + FileAttributes.FILE_ATTRIBUTE_NORMAL, + 0, + CreateDisposition.FILE_OVERWRITE_IF, + CreateOptions.FILE_NON_DIRECTORY_FILE) + + actual = open.write(b"\x01\x02\x03\x04") + assert actual == 4 + actual = open.read(0, 4) + assert actual == b"\x01\x02\x03\x04" + finally: + connection.disconnect(True) + + def test_flush_file(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3]) + connection.connect(Dialects.SMB_3_0_2) + session = Session(connection, smb_real[0], smb_real[1]) + tree = TreeConnect(session, smb_real[5]) + open = Open(tree, "file-cont.txt") + try: + session.connect() + tree.connect() + + open.open(ImpersonationLevel.Impersonation, + FilePipePrinterAccessMask.MAXIMUM_ALLOWED, + FileAttributes.FILE_ATTRIBUTE_NORMAL, + 0, + CreateDisposition.FILE_OVERWRITE_IF, + CreateOptions.FILE_NON_DIRECTORY_FILE) + open.flush() + finally: + connection.disconnect(True) + + def test_close_file_dont_get_attributes(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3]) + connection.connect() + session = Session(connection, smb_real[0], smb_real[1]) + tree = TreeConnect(session, smb_real[4]) + open = Open(tree, "file-read-write.txt") + try: + session.connect() + tree.connect() + + open.open(ImpersonationLevel.Impersonation, + FilePipePrinterAccessMask.MAXIMUM_ALLOWED, + FileAttributes.FILE_ATTRIBUTE_NORMAL, + 0, + CreateDisposition.FILE_OVERWRITE_IF, + CreateOptions.FILE_NON_DIRECTORY_FILE) + + old_last_write_time = open.last_write_time + old_end_of_file = open.end_of_file + open.write(b"\x01") + open.close(False) + assert open.last_write_time == old_last_write_time + assert open.end_of_file == old_end_of_file + finally: + open.close(False) # test close when it has already been closed + connection.disconnect(True) + + def test_close_file_get_attributes(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3]) + connection.connect() + session = Session(connection, smb_real[0], smb_real[1]) + tree = TreeConnect(session, smb_real[4]) + open = Open(tree, "file-read-write.txt") + try: + session.connect() + tree.connect() + + open.open(ImpersonationLevel.Impersonation, + FilePipePrinterAccessMask.MAXIMUM_ALLOWED, + FileAttributes.FILE_ATTRIBUTE_NORMAL, + 0, + CreateDisposition.FILE_OVERWRITE_IF, + CreateOptions.FILE_NON_DIRECTORY_FILE) + + old_last_write_time = open.last_write_time + old_end_of_file = open.end_of_file + open.write(b"\x01") + open.close(True) + assert open.last_write_time != old_last_write_time + assert open.end_of_file != old_end_of_file + assert open.end_of_file == 1 + finally: + connection.disconnect(True) + + def test_read_file_unbuffered(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3]) + connection.connect(Dialects.SMB_3_0_2) + session = Session(connection, smb_real[0], smb_real[1]) + tree = TreeConnect(session, smb_real[4]) + open = Open(tree, "file-read-write.txt") + try: + session.connect() + tree.connect() + + open.open(ImpersonationLevel.Impersonation, + FilePipePrinterAccessMask.MAXIMUM_ALLOWED, + FileAttributes.FILE_ATTRIBUTE_NORMAL, + 0, + CreateDisposition.FILE_OVERWRITE_IF, + CreateOptions.FILE_NON_DIRECTORY_FILE) + + open.write(b"\x01") + actual = open.read(0, 1, unbuffered=True) + assert actual == b"\x01" + finally: + connection.disconnect(True) + + def test_read_file_unbuffered_unsupported(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3]) + connection.connect(Dialects.SMB_3_0_0) + session = Session(connection, smb_real[0], smb_real[1]) + tree = TreeConnect(session, smb_real[4]) + open = Open(tree, "file-read-write.txt") + try: + session.connect() + tree.connect() + + open.open(ImpersonationLevel.Impersonation, + FilePipePrinterAccessMask.MAXIMUM_ALLOWED, + FileAttributes.FILE_ATTRIBUTE_NORMAL, + 0, + CreateDisposition.FILE_OVERWRITE_IF, + CreateOptions.FILE_NON_DIRECTORY_FILE) + + open.write(b"\x01") + with pytest.raises(SMBUnsupportedFeature) as exc: + open.read(0, 1, unbuffered=True) + assert exc.value.feature_name == "SMB2_READFLAG_READ_UNBUFFERED" + assert exc.value.negotiated_dialect == Dialects.SMB_3_0_0 + assert exc.value.required_dialect == Dialects.SMB_3_0_2 + assert exc.value.requires_newer + assert str(exc.value) == \ + "SMB2_READFLAG_READ_UNBUFFERED is not available on the " \ + "negotiated dialect (768) SMB_3_0_0, requires dialect (770) " \ + "SMB_3_0_2 or newer" + finally: + connection.disconnect(True) + + @pytest.mark.skipif(os.name == "nt", + reason="write-through writes don't work on windows?") + def test_write_file_write_through(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3]) + connection.connect(Dialects.SMB_3_0_2) + session = Session(connection, smb_real[0], smb_real[1]) + tree = TreeConnect(session, smb_real[4]) + open = Open(tree, "file-read-write.txt") + try: + session.connect() + tree.connect() + + open.open(ImpersonationLevel.Impersonation, + FilePipePrinterAccessMask.MAXIMUM_ALLOWED, + FileAttributes.FILE_ATTRIBUTE_NORMAL, + 0, + CreateDisposition.FILE_OVERWRITE_IF, + CreateOptions.FILE_NON_DIRECTORY_FILE | + CreateOptions.FILE_WRITE_THROUGH) + + actual = open.write(b"\x01", write_through=True) + assert actual == 1 + actual = open.read(0, 1) + assert actual == b"\x01" + finally: + connection.disconnect(True) + + def test_write_file_write_through_unsupported(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3]) + connection.connect(Dialects.SMB_2_0_2) + session = Session(connection, smb_real[0], smb_real[1], False) + tree = TreeConnect(session, smb_real[4]) + open = Open(tree, "file-read-write.txt") + try: + session.connect() + tree.connect() + + open.open(ImpersonationLevel.Impersonation, + FilePipePrinterAccessMask.MAXIMUM_ALLOWED, + FileAttributes.FILE_ATTRIBUTE_NORMAL, + 0, + CreateDisposition.FILE_OVERWRITE_IF, + CreateOptions.FILE_NON_DIRECTORY_FILE | + CreateOptions.FILE_WRITE_THROUGH) + + with pytest.raises(SMBUnsupportedFeature) as exc: + open.write(b"\x01", write_through=True) + assert exc.value.feature_name == "SMB2_WRITEFLAG_WRITE_THROUGH" + assert exc.value.negotiated_dialect == Dialects.SMB_2_0_2 + assert exc.value.required_dialect == Dialects.SMB_2_1_0 + assert exc.value.requires_newer + assert str(exc.value) == \ + "SMB2_WRITEFLAG_WRITE_THROUGH is not available on the " \ + "negotiated dialect (514) SMB_2_0_2, requires dialect (528) " \ + "SMB_2_1_0 or newer" + finally: + connection.disconnect(True) + + @pytest.mark.skipif(os.name == "nt", + reason="unbufferred writes don't work on windows?") + def test_write_file_unbuffered(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3]) + connection.connect(Dialects.SMB_3_0_2) + session = Session(connection, smb_real[0], smb_real[1]) + tree = TreeConnect(session, smb_real[4]) + open = Open(tree, "file-read-write.txt") + try: + session.connect() + tree.connect() + + open.open(ImpersonationLevel.Impersonation, + FilePipePrinterAccessMask.MAXIMUM_ALLOWED, + FileAttributes.FILE_ATTRIBUTE_NORMAL, + 0, + CreateDisposition.FILE_OVERWRITE_IF, + CreateOptions.FILE_NON_DIRECTORY_FILE | + CreateOptions.FILE_NO_INTERMEDIATE_BUFFERING) + + actual = open.write(b"\x01", unbuffered=True) + assert actual == 1 + actual = open.read(0, 1) + assert actual == b"\x01" + finally: + connection.disconnect(True) + + def test_write_file_unbuffered_unsupported(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3]) + connection.connect(Dialects.SMB_2_1_0) + session = Session(connection, smb_real[0], smb_real[1], False) + tree = TreeConnect(session, smb_real[4]) + open = Open(tree, "file-read-write.txt") + try: + session.connect() + tree.connect() + + open.open(ImpersonationLevel.Impersonation, + FilePipePrinterAccessMask.MAXIMUM_ALLOWED, + FileAttributes.FILE_ATTRIBUTE_NORMAL, + 0, + CreateDisposition.FILE_OVERWRITE_IF, + CreateOptions.FILE_NON_DIRECTORY_FILE | + CreateOptions.FILE_NO_INTERMEDIATE_BUFFERING) + + with pytest.raises(SMBUnsupportedFeature) as exc: + open.write(b"\x01", unbuffered=True) + assert exc.value.feature_name == "SMB2_WRITEFLAG_WRITE_UNBUFFERED" + assert exc.value.negotiated_dialect == Dialects.SMB_2_1_0 + assert exc.value.required_dialect == Dialects.SMB_3_0_2 + assert exc.value.requires_newer + assert str(exc.value) == \ + "SMB2_WRITEFLAG_WRITE_UNBUFFERED is not available on the " \ + "negotiated dialect (528) SMB_2_1_0, requires dialect (770) " \ + "SMB_3_0_2 or newer" + finally: + connection.disconnect(True) + + def test_query_directory(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3]) + connection.connect(Dialects.SMB_3_0_2) + session = Session(connection, smb_real[0], smb_real[1]) + tree = TreeConnect(session, smb_real[4]) + open = Open(tree, "directory") + try: + session.connect() + tree.connect() + + open.open(ImpersonationLevel.Impersonation, + DirectoryAccessMask.MAXIMUM_ALLOWED, + FileAttributes.FILE_ATTRIBUTE_DIRECTORY, + ShareAccess.FILE_SHARE_READ | + ShareAccess.FILE_SHARE_WRITE | + ShareAccess.FILE_SHARE_DELETE, + CreateDisposition.FILE_OPEN_IF, + CreateOptions.FILE_DIRECTORY_FILE) + + file1 = Open(tree, r"directory\\file1.txt") + file1.open(ImpersonationLevel.Impersonation, + FilePipePrinterAccessMask.MAXIMUM_ALLOWED, + FileAttributes.FILE_ATTRIBUTE_NORMAL, + ShareAccess.FILE_SHARE_READ, + CreateDisposition.FILE_OVERWRITE_IF, + CreateOptions.FILE_NON_DIRECTORY_FILE) + file1.write(b"\x01\x02\x03\x04", 0) + + file2 = Open(tree, r"directory\\file2.log") + file2.open(ImpersonationLevel.Impersonation, + FilePipePrinterAccessMask.MAXIMUM_ALLOWED, + FileAttributes.FILE_ATTRIBUTE_NORMAL, + ShareAccess.FILE_SHARE_READ, + CreateDisposition.FILE_OVERWRITE_IF, + CreateOptions.FILE_NON_DIRECTORY_FILE) + file2.write(b"\x05\x06", 0) + + actual = open.query_directory("*", + FileInformationClass. + FILE_NAMES_INFORMATION) + + assert len(actual) == 4 + assert isinstance(actual[0], FileNamesInformation) + assert actual[0]['file_name'].get_value().decode('utf-16-le') == \ + "." + assert isinstance(actual[1], FileNamesInformation) + assert actual[1]['file_name'].get_value().decode('utf-16-le') == \ + ".." + + file1_name = "file1.txt".encode('utf-16-le') + file2_name = "file2.log".encode('utf-16-le') + assert isinstance(actual[2], FileNamesInformation) + assert actual[2]['file_name'].get_value() in \ + [file1_name, file2_name] + assert isinstance(actual[3], FileNamesInformation) + assert actual[3]['file_name'].get_value() in \ + [file1_name, file2_name] + + open.close() + finally: + connection.disconnect(True) + + @pytest.mark.skipif(os.name == "nt", + reason="flush in compound does't work on windows") + def test_compounding_open_requests(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3]) + connection.connect(Dialects.SMB_3_0_2) + session = Session(connection, smb_real[0], smb_real[1]) + tree = TreeConnect(session, smb_real[4]) + open = Open(tree, "directory") + try: + session.connect() + tree.connect() + + open.open(ImpersonationLevel.Impersonation, + DirectoryAccessMask.MAXIMUM_ALLOWED, + FileAttributes.FILE_ATTRIBUTE_DIRECTORY, + ShareAccess.FILE_SHARE_READ | + ShareAccess.FILE_SHARE_WRITE | + ShareAccess.FILE_SHARE_DELETE, + CreateDisposition.FILE_OPEN_IF, + CreateOptions.FILE_DIRECTORY_FILE) + + file1 = Open(tree, r"directory\\file1.txt") + file1.open(ImpersonationLevel.Impersonation, + FilePipePrinterAccessMask.MAXIMUM_ALLOWED, + FileAttributes.FILE_ATTRIBUTE_NORMAL, + ShareAccess.FILE_SHARE_READ, + CreateDisposition.FILE_OVERWRITE_IF, + CreateOptions.FILE_NON_DIRECTORY_FILE) + file2 = Open(tree, r"directory\\file2.log") + file2.open(ImpersonationLevel.Impersonation, + FilePipePrinterAccessMask.MAXIMUM_ALLOWED, + FileAttributes.FILE_ATTRIBUTE_NORMAL, + ShareAccess.FILE_SHARE_READ, + CreateDisposition.FILE_OVERWRITE_IF, + CreateOptions.FILE_NON_DIRECTORY_FILE) + + # create messages for each operation + messages = [ + file1.write(b"\x01\x02\x03\x04", 0, send=False), + file2.write(b"\x05\x06", 0, send=False), + file1.flush(send=False), + file1.read(0, 4, send=False), + open.query_directory("*", + FileInformationClass. + FILE_ID_BOTH_DIRECTORY_INFORMATION, + send=False), + file1.close(send=False), + file2.close(send=False) + ] + + # send each message as a compound request + requests = connection.send_compound([x[0] for x in messages], + session.session_id, + tree.tree_connect_id) + + # get responses and run unpack function + responses = [] + for i, request in enumerate(requests): + response = messages[i][1](request) + responses.append(response) + + # assert each response + assert len(responses) == 7 + assert isinstance(responses[0], int) + assert isinstance(responses[1], int) + assert isinstance(responses[2], SMB2FlushResponse) + assert isinstance(responses[3], bytes) + assert isinstance(responses[4], list) + assert isinstance(responses[5], SMB2CloseResponse) + assert isinstance(responses[6], SMB2CloseResponse) + + write1 = responses[0] + assert write1 == 4 + + write2 = responses[1] + assert write2 == 2 + + read1 = responses[3] + assert read1 == b"\x01\x02\x03\x04" + + query1 = responses[4] + assert query1[0]['file_name'].get_value() == \ + ".".encode('utf-16-le') + assert query1[1]['file_name'].get_value() == \ + "..".encode('utf-16-le') + file1_name = "file1.txt".encode('utf-16-le') + file2_name = "file2.log".encode('utf-16-le') + assert query1[2]['file_name'].get_value() \ + in [file1_name, file2_name] + assert query1[3]['file_name'].get_value() \ + in [file1_name, file2_name] + + open.close() + finally: + connection.disconnect(True) + + @pytest.mark.skipif(os.name == "nt", + reason="flush in compound does't work on windows") + def test_compounding_open_requests_unencrypted(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3]) + connection.connect(Dialects.SMB_2_1_0) + session = Session(connection, smb_real[0], smb_real[1], False) + tree = TreeConnect(session, smb_real[4]) + open = Open(tree, "directory") + try: + session.connect() + tree.connect() + + open.open(ImpersonationLevel.Impersonation, + DirectoryAccessMask.MAXIMUM_ALLOWED, + FileAttributes.FILE_ATTRIBUTE_DIRECTORY, + ShareAccess.FILE_SHARE_READ | + ShareAccess.FILE_SHARE_WRITE | + ShareAccess.FILE_SHARE_DELETE, + CreateDisposition.FILE_OPEN_IF, + CreateOptions.FILE_DIRECTORY_FILE) + + file1 = Open(tree, r"directory\\file1.txt") + file1.open(ImpersonationLevel.Impersonation, + FilePipePrinterAccessMask.MAXIMUM_ALLOWED, + FileAttributes.FILE_ATTRIBUTE_NORMAL, + ShareAccess.FILE_SHARE_READ, + CreateDisposition.FILE_OVERWRITE_IF, + CreateOptions.FILE_NON_DIRECTORY_FILE) + file2 = Open(tree, r"directory\\file2.log") + file2.open(ImpersonationLevel.Impersonation, + FilePipePrinterAccessMask.MAXIMUM_ALLOWED, + FileAttributes.FILE_ATTRIBUTE_NORMAL, + ShareAccess.FILE_SHARE_READ, + CreateDisposition.FILE_OVERWRITE_IF, + CreateOptions.FILE_NON_DIRECTORY_FILE) + + # create messages for each operation + messages = [ + file1.write(b"\x01\x02\x03\x04", 0, send=False), + file2.write(b"\x05\x06", 0, send=False), + file1.flush(send=False), + file1.read(0, 4, send=False), + open.query_directory("*", + FileInformationClass. + FILE_ID_BOTH_DIRECTORY_INFORMATION, + send=False), + file1.close(send=False), + file2.close(send=False) + ] + + # send each message as a compound request + requests = connection.send_compound([x[0] for x in messages], + session.session_id, + tree.tree_connect_id) + + # get responses and run unpack function + responses = [] + for i, request in enumerate(requests): + response = messages[i][1](request) + responses.append(response) + + # assert each response + assert len(responses) == 7 + assert isinstance(responses[0], int) + assert isinstance(responses[1], int) + assert isinstance(responses[2], SMB2FlushResponse) + assert isinstance(responses[3], bytes) + assert isinstance(responses[4], list) + assert isinstance(responses[5], SMB2CloseResponse) + assert isinstance(responses[6], SMB2CloseResponse) + + write1 = responses[0] + assert write1 == 4 + + write2 = responses[1] + assert write2 == 2 + + read1 = responses[3] + assert read1 == b"\x01\x02\x03\x04" + + query1 = responses[4] + assert query1[0]['file_name'].get_value() == \ + ".".encode('utf-16-le') + assert query1[1]['file_name'].get_value() == \ + "..".encode('utf-16-le') + file1_name = "file1.txt".encode('utf-16-le') + file2_name = "file2.log".encode('utf-16-le') + assert query1[2]['file_name'].get_value() \ + in [file1_name, file2_name] + assert query1[3]['file_name'].get_value() \ + in [file1_name, file2_name] + + open.close() + finally: + connection.disconnect(True) + + def test_close_file_already_closed(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3]) + connection.connect(Dialects.SMB_3_0_2) + session = Session(connection, smb_real[0], smb_real[1]) + tree = TreeConnect(session, smb_real[4]) + open = Open(tree, "file-read-write.txt") + try: + session.connect() + tree.connect() + + open.open(ImpersonationLevel.Impersonation, + FilePipePrinterAccessMask.MAXIMUM_ALLOWED, + FileAttributes.FILE_ATTRIBUTE_NORMAL, + 0, + CreateDisposition.FILE_OVERWRITE_IF, + CreateOptions.FILE_NON_DIRECTORY_FILE) + open.close() + + # we will just manually say it is still connected so we get the + # proper error msg + open._connected = True + open.close() + finally: + connection.disconnect(True) + + def test_read_greater_than_max_size(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3]) + connection.connect() + session = Session(connection, smb_real[0], smb_real[1]) + tree = TreeConnect(session, smb_real[4]) + open = Open(tree, "file.txt") + try: + session.connect() + tree.connect() + + open.open(ImpersonationLevel.Impersonation, + FilePipePrinterAccessMask.MAXIMUM_ALLOWED, + FileAttributes.FILE_ATTRIBUTE_NORMAL, + 0, + CreateDisposition.FILE_OVERWRITE_IF, + CreateOptions.FILE_NON_DIRECTORY_FILE) + + with pytest.raises(SMBException) as exc: + open.read(0, connection.max_read_size + 1) + assert str(exc.value) == "The requested read length %d is " \ + "greater than the maximum negotiated " \ + "read size %d"\ + % (connection.max_read_size + 1, connection.max_read_size) + finally: + connection.disconnect(True) + + def test_write_greater_than_max_size(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3]) + connection.connect() + session = Session(connection, smb_real[0], smb_real[1]) + tree = TreeConnect(session, smb_real[4]) + open = Open(tree, "file.txt") + try: + session.connect() + tree.connect() + + open.open(ImpersonationLevel.Impersonation, + FilePipePrinterAccessMask.MAXIMUM_ALLOWED, + FileAttributes.FILE_ATTRIBUTE_NORMAL, + 0, + CreateDisposition.FILE_OVERWRITE_IF, + CreateOptions.FILE_NON_DIRECTORY_FILE) + + with pytest.raises(SMBException) as exc: + open.write(b"\x00" * (connection.max_write_size + 1), 0) + assert str(exc.value) == "The requested write length %d is " \ + "greater than the maximum negotiated " \ + "write size %d"\ + % (connection.max_write_size + 1, connection.max_write_size) + finally: + connection.disconnect(True) + + def test_read_file_multi_credits(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3]) + connection.connect() + session = Session(connection, smb_real[0], smb_real[1]) + tree = TreeConnect(session, smb_real[4]) + open = Open(tree, "file.txt") + try: + session.connect() + tree.connect() + + open.open(ImpersonationLevel.Impersonation, + FilePipePrinterAccessMask.MAXIMUM_ALLOWED, + FileAttributes.FILE_ATTRIBUTE_NORMAL, + 0, + CreateDisposition.FILE_OVERWRITE_IF, + CreateOptions.FILE_NON_DIRECTORY_FILE) + open.write(b"\x01\x02\x03\x04", 0) + actual = open.read(0, 65538) + assert actual == b"\x01\x02\x03\x04" + finally: + connection.disconnect(True) diff --git a/tests/test_query_info.py b/tests/test_query_info.py new file mode 100644 index 00000000..852adbbd --- /dev/null +++ b/tests/test_query_info.py @@ -0,0 +1,423 @@ +from datetime import datetime + +from smbprotocol.query_info import FileBothDirectoryInformation, \ + FileDirectoryInformation, FileFullDirectoryInformation, \ + FileIdBothDirectoryInformation, FileIdFullDirectoryInformation, \ + FileNamesInformation + + +class TestFileBothDirectoryInformation(object): + + def test_create_message(self): + message = FileBothDirectoryInformation() + message['creation_time'] = datetime.utcfromtimestamp(1024) + message['last_access_time'] = datetime.utcfromtimestamp(1024) + message['last_write_time'] = datetime.utcfromtimestamp(1024) + message['change_time'] = datetime.utcfromtimestamp(1024) + message['end_of_file'] = 4 + message['allocation_size'] = 1048576 + message['file_attributes'] = 32 + message['file_name'] = "file1.txt".encode("utf-16-le") + + expected = b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x80\x98\x37\xe1\xb1\x9d\x01" \ + b"\x00\x80\x98\x37\xe1\xb1\x9d\x01" \ + b"\x00\x80\x98\x37\xe1\xb1\x9d\x01" \ + b"\x00\x80\x98\x37\xe1\xb1\x9d\x01" \ + b"\x04\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x10\x00\x00\x00\x00\x00" \ + b"\x20\x00\x00\x00" \ + b"\x12\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00" \ + b"\x00" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\x66\x00\x69\x00\x6C\x00\x65\x00" \ + b"\x31\x00\x2E\x00\x74\x00\x78\x00" \ + b"\x74\x00" + actual = message.pack() + assert len(message) == 112 + assert actual == expected + + def test_parse_message(self): + actual = FileBothDirectoryInformation() + data = b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x80\x98\x37\xe1\xb1\x9d\x01" \ + b"\x00\x80\x98\x37\xe1\xb1\x9d\x01" \ + b"\x00\x80\x98\x37\xe1\xb1\x9d\x01" \ + b"\x00\x80\x98\x37\xe1\xb1\x9d\x01" \ + b"\x04\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x10\x00\x00\x00\x00\x00" \ + b"\x20\x00\x00\x00" \ + b"\x12\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00" \ + b"\x00" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\x66\x00\x69\x00\x6C\x00\x65\x00" \ + b"\x31\x00\x2E\x00\x74\x00\x78\x00" \ + b"\x74\x00" + data = actual.unpack(data) + assert len(actual) == 112 + assert data == b"" + assert actual['next_entry_offset'].get_value() == 0 + assert actual['file_index'].get_value() == 0 + assert actual['creation_time'].get_value() == \ + datetime.utcfromtimestamp(1024) + assert actual['last_access_time'].get_value() == \ + datetime.utcfromtimestamp(1024) + assert actual['last_write_time'].get_value() == \ + datetime.utcfromtimestamp(1024) + assert actual['change_time'].get_value() == \ + datetime.utcfromtimestamp(1024) + assert actual['end_of_file'].get_value() == 4 + assert actual['allocation_size'].get_value() == 1048576 + assert actual['file_attributes'].get_value() == 32 + assert actual['file_name_length'].get_value() == 18 + assert actual['ea_size'].get_value() == 0 + assert actual['short_name_length'].get_value() == 0 + assert actual['reserved'].get_value() == 0 + assert actual['short_name'].get_value() == b"" + assert actual['short_name_padding'].get_value() == b"\x00" * 24 + assert actual['file_name'].get_value() == \ + "file1.txt".encode('utf-16-le') + + +class TestFileDirectoryInformation(object): + + def test_create_message(self): + message = FileDirectoryInformation() + message['creation_time'] = datetime.utcfromtimestamp(1024) + message['last_access_time'] = datetime.utcfromtimestamp(1024) + message['last_write_time'] = datetime.utcfromtimestamp(1024) + message['change_time'] = datetime.utcfromtimestamp(1024) + message['end_of_file'] = 4 + message['allocation_size'] = 1048576 + message['file_attributes'] = 32 + message['file_name'] = "file1.txt".encode("utf-16-le") + + expected = b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x80\x98\x37\xe1\xb1\x9d\x01" \ + b"\x00\x80\x98\x37\xe1\xb1\x9d\x01" \ + b"\x00\x80\x98\x37\xe1\xb1\x9d\x01" \ + b"\x00\x80\x98\x37\xe1\xb1\x9d\x01" \ + b"\x04\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x10\x00\x00\x00\x00\x00" \ + b"\x20\x00\x00\x00" \ + b"\x12\x00\x00\x00" \ + b"\x66\x00\x69\x00\x6C\x00\x65\x00" \ + b"\x31\x00\x2E\x00\x74\x00\x78\x00" \ + b"\x74\x00" + actual = message.pack() + assert len(message) == 82 + assert actual == expected + + def test_parse_message(self): + actual = FileDirectoryInformation() + data = b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x80\x98\x37\xe1\xb1\x9d\x01" \ + b"\x00\x80\x98\x37\xe1\xb1\x9d\x01" \ + b"\x00\x80\x98\x37\xe1\xb1\x9d\x01" \ + b"\x00\x80\x98\x37\xe1\xb1\x9d\x01" \ + b"\x04\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x10\x00\x00\x00\x00\x00" \ + b"\x20\x00\x00\x00" \ + b"\x12\x00\x00\x00" \ + b"\x66\x00\x69\x00\x6C\x00\x65\x00" \ + b"\x31\x00\x2E\x00\x74\x00\x78\x00" \ + b"\x74\x00" + data = actual.unpack(data) + assert len(actual) == 82 + assert data == b"" + assert actual['next_entry_offset'].get_value() == 0 + assert actual['file_index'].get_value() == 0 + assert actual['creation_time'].get_value() == \ + datetime.utcfromtimestamp(1024) + assert actual['last_access_time'].get_value() == \ + datetime.utcfromtimestamp(1024) + assert actual['last_write_time'].get_value() == \ + datetime.utcfromtimestamp(1024) + assert actual['change_time'].get_value() == \ + datetime.utcfromtimestamp(1024) + assert actual['end_of_file'].get_value() == 4 + assert actual['allocation_size'].get_value() == 1048576 + assert actual['file_attributes'].get_value() == 32 + assert actual['file_name_length'].get_value() == 18 + assert actual['file_name'].get_value() == \ + "file1.txt".encode('utf-16-le') + + +class TestFileFullDirectoryInformation(object): + + def test_create_message(self): + message = FileFullDirectoryInformation() + message['creation_time'] = datetime.utcfromtimestamp(1024) + message['last_access_time'] = datetime.utcfromtimestamp(1024) + message['last_write_time'] = datetime.utcfromtimestamp(1024) + message['change_time'] = datetime.utcfromtimestamp(1024) + message['end_of_file'] = 4 + message['allocation_size'] = 1048576 + message['file_attributes'] = 32 + message['file_name'] = "file1.txt".encode("utf-16-le") + + expected = b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x80\x98\x37\xe1\xb1\x9d\x01" \ + b"\x00\x80\x98\x37\xe1\xb1\x9d\x01" \ + b"\x00\x80\x98\x37\xe1\xb1\x9d\x01" \ + b"\x00\x80\x98\x37\xe1\xb1\x9d\x01" \ + b"\x04\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x10\x00\x00\x00\x00\x00" \ + b"\x20\x00\x00\x00" \ + b"\x12\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x66\x00\x69\x00\x6C\x00\x65\x00" \ + b"\x31\x00\x2E\x00\x74\x00\x78\x00" \ + b"\x74\x00" + actual = message.pack() + assert len(message) == 86 + assert actual == expected + + def test_parse_message(self): + actual = FileFullDirectoryInformation() + data = b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x80\x98\x37\xe1\xb1\x9d\x01" \ + b"\x00\x80\x98\x37\xe1\xb1\x9d\x01" \ + b"\x00\x80\x98\x37\xe1\xb1\x9d\x01" \ + b"\x00\x80\x98\x37\xe1\xb1\x9d\x01" \ + b"\x04\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x10\x00\x00\x00\x00\x00" \ + b"\x20\x00\x00\x00" \ + b"\x12\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x66\x00\x69\x00\x6C\x00\x65\x00" \ + b"\x31\x00\x2E\x00\x74\x00\x78\x00" \ + b"\x74\x00" + data = actual.unpack(data) + assert len(actual) == 86 + assert data == b"" + assert actual['next_entry_offset'].get_value() == 0 + assert actual['file_index'].get_value() == 0 + assert actual['creation_time'].get_value() == \ + datetime.utcfromtimestamp(1024) + assert actual['last_access_time'].get_value() == \ + datetime.utcfromtimestamp(1024) + assert actual['last_write_time'].get_value() == \ + datetime.utcfromtimestamp(1024) + assert actual['change_time'].get_value() == \ + datetime.utcfromtimestamp(1024) + assert actual['end_of_file'].get_value() == 4 + assert actual['allocation_size'].get_value() == 1048576 + assert actual['file_attributes'].get_value() == 32 + assert actual['file_name_length'].get_value() == 18 + assert actual['ea_size'].get_value() == 0 + assert actual['file_name'].get_value() == \ + "file1.txt".encode('utf-16-le') + + +class TestFileIdBothDirectoryInformation(object): + + def test_create_message(self): + message = FileIdBothDirectoryInformation() + message['creation_time'] = datetime.utcfromtimestamp(1024) + message['last_access_time'] = datetime.utcfromtimestamp(1024) + message['last_write_time'] = datetime.utcfromtimestamp(1024) + message['change_time'] = datetime.utcfromtimestamp(1024) + message['end_of_file'] = 4 + message['allocation_size'] = 1048576 + message['file_attributes'] = 32 + message['file_id'] = 8800388263864 + message['file_name'] = "file1.txt".encode("utf-16-le") + + expected = b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x80\x98\x37\xe1\xb1\x9d\x01" \ + b"\x00\x80\x98\x37\xe1\xb1\x9d\x01" \ + b"\x00\x80\x98\x37\xe1\xb1\x9d\x01" \ + b"\x00\x80\x98\x37\xe1\xb1\x9d\x01" \ + b"\x04\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x10\x00\x00\x00\x00\x00" \ + b"\x20\x00\x00\x00" \ + b"\x12\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00" \ + b"\x00" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00" \ + b"\xB8\x2F\x04\x00\x01\x08\x00\x00" \ + b"\x66\x00\x69\x00\x6C\x00\x65\x00" \ + b"\x31\x00\x2E\x00\x74\x00\x78\x00" \ + b"\x74\x00" + actual = message.pack() + assert len(message) == 122 + assert actual == expected + + def test_parse_message(self): + actual = FileIdBothDirectoryInformation() + data = b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x80\x98\x37\xe1\xb1\x9d\x01" \ + b"\x00\x80\x98\x37\xe1\xb1\x9d\x01" \ + b"\x00\x80\x98\x37\xe1\xb1\x9d\x01" \ + b"\x00\x80\x98\x37\xe1\xb1\x9d\x01" \ + b"\x04\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x10\x00\x00\x00\x00\x00" \ + b"\x20\x00\x00\x00" \ + b"\x12\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00" \ + b"\x00" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00" \ + b"\xB8\x2F\x04\x00\x01\x08\x00\x00" \ + b"\x66\x00\x69\x00\x6C\x00\x65\x00" \ + b"\x31\x00\x2E\x00\x74\x00\x78\x00" \ + b"\x74\x00" + data = actual.unpack(data) + assert len(actual) == 122 + assert data == b"" + assert actual['next_entry_offset'].get_value() == 0 + assert actual['file_index'].get_value() == 0 + assert actual['creation_time'].get_value() == \ + datetime.utcfromtimestamp(1024) + assert actual['last_access_time'].get_value() == \ + datetime.utcfromtimestamp(1024) + assert actual['last_write_time'].get_value() == \ + datetime.utcfromtimestamp(1024) + assert actual['change_time'].get_value() == \ + datetime.utcfromtimestamp(1024) + assert actual['end_of_file'].get_value() == 4 + assert actual['allocation_size'].get_value() == 1048576 + assert actual['file_attributes'].get_value() == 32 + assert actual['file_name_length'].get_value() == 18 + assert actual['ea_size'].get_value() == 0 + assert actual['short_name_length'].get_value() == 0 + assert actual['reserved1'].get_value() == 0 + assert actual['short_name'].get_value() == b"" + assert actual['short_name_padding'].get_value() == b"\x00" * 24 + assert actual['reserved2'].get_value() == 0 + assert actual['file_id'].get_value() == 8800388263864 + assert actual['file_name'].get_value() == \ + "file1.txt".encode('utf-16-le') + + +class TestFileIdFullDirectoryInformation(object): + + def test_create_message(self): + message = FileIdFullDirectoryInformation() + message['creation_time'] = datetime.utcfromtimestamp(1024) + message['last_access_time'] = datetime.utcfromtimestamp(1024) + message['last_write_time'] = datetime.utcfromtimestamp(1024) + message['change_time'] = datetime.utcfromtimestamp(1024) + message['end_of_file'] = 4 + message['allocation_size'] = 1048576 + message['file_attributes'] = 32 + message['file_id'] = 8800388263864 + message['file_name'] = "file1.txt".encode("utf-16-le") + + expected = b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x80\x98\x37\xe1\xb1\x9d\x01" \ + b"\x00\x80\x98\x37\xe1\xb1\x9d\x01" \ + b"\x00\x80\x98\x37\xe1\xb1\x9d\x01" \ + b"\x00\x80\x98\x37\xe1\xb1\x9d\x01" \ + b"\x04\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x10\x00\x00\x00\x00\x00" \ + b"\x20\x00\x00\x00" \ + b"\x12\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\xB8\x2F\x04\x00\x01\x08\x00\x00" \ + b"\x66\x00\x69\x00\x6C\x00\x65\x00" \ + b"\x31\x00\x2E\x00\x74\x00\x78\x00" \ + b"\x74\x00" + actual = message.pack() + assert len(message) == 98 + assert actual == expected + + def test_parse_message(self): + actual = FileIdFullDirectoryInformation() + data = b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x80\x98\x37\xe1\xb1\x9d\x01" \ + b"\x00\x80\x98\x37\xe1\xb1\x9d\x01" \ + b"\x00\x80\x98\x37\xe1\xb1\x9d\x01" \ + b"\x00\x80\x98\x37\xe1\xb1\x9d\x01" \ + b"\x04\x00\x00\x00\x00\x00\x00\x00" \ + b"\x00\x00\x10\x00\x00\x00\x00\x00" \ + b"\x20\x00\x00\x00" \ + b"\x12\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\xB8\x2F\x04\x00\x01\x08\x00\x00" \ + b"\x66\x00\x69\x00\x6C\x00\x65\x00" \ + b"\x31\x00\x2E\x00\x74\x00\x78\x00" \ + b"\x74\x00" + data = actual.unpack(data) + assert len(actual) == 98 + assert data == b"" + assert actual['next_entry_offset'].get_value() == 0 + assert actual['file_index'].get_value() == 0 + assert actual['creation_time'].get_value() == \ + datetime.utcfromtimestamp(1024) + assert actual['last_access_time'].get_value() == \ + datetime.utcfromtimestamp(1024) + assert actual['last_write_time'].get_value() == \ + datetime.utcfromtimestamp(1024) + assert actual['change_time'].get_value() == \ + datetime.utcfromtimestamp(1024) + assert actual['end_of_file'].get_value() == 4 + assert actual['allocation_size'].get_value() == 1048576 + assert actual['file_attributes'].get_value() == 32 + assert actual['file_name_length'].get_value() == 18 + assert actual['ea_size'].get_value() == 0 + assert actual['reserved'].get_value() == 0 + assert actual['file_id'].get_value() == 8800388263864 + assert actual['file_name'].get_value() == \ + "file1.txt".encode('utf-16-le') + + +class TestFileNamesInformation(object): + + def test_create_message(self): + message = FileNamesInformation() + message['file_name'] = "file1.txt".encode('utf-16-le') + expected = b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x12\x00\x00\x00" \ + b"\x66\x00\x69\x00\x6C\x00\x65\x00" \ + b"\x31\x00\x2E\x00\x74\x00\x78\x00" \ + b"\x74\x00" + actual = message.pack() + assert len(message) == 30 + assert actual == expected + + def test_parse_message(self): + actual = FileNamesInformation() + data = b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x12\x00\x00\x00" \ + b"\x66\x00\x69\x00\x6C\x00\x65\x00" \ + b"\x31\x00\x2E\x00\x74\x00\x78\x00" \ + b"\x74\x00" + data = actual.unpack(data) + assert len(actual) == 30 + assert data == b"" + assert actual['next_entry_offset'].get_value() == 0 + assert actual['file_index'].get_value() == 0 + assert actual['file_name_length'].get_value() == 18 + assert actual['file_name'].get_value() == \ + "file1.txt".encode('utf-16-le') diff --git a/tests/test_security_descriptor.py b/tests/test_security_descriptor.py new file mode 100644 index 00000000..4d841716 --- /dev/null +++ b/tests/test_security_descriptor.py @@ -0,0 +1,633 @@ +import pytest + +from smbprotocol.security_descriptor import AccessAllowedAce, \ + AccessDeniedAce, AceType, AclPacket, AclRevision, SDControl, SIDPacket, \ + SMB2CreateSDBuffer, SystemAuditAce + + +class TestSIDPacket(object): + + def test_create_message(self): + sid = "S-1-1-0" + message = SIDPacket() + message.from_string(sid) + expected = b"\x01" \ + b"\x01" \ + b"\x00\x00" \ + b"\x00\x00\x00\x01" \ + b"\x00\x00\x00\x00" + actual = message.pack() + assert len(message) == 12 + assert actual == expected + assert str(message) == sid + + def test_create_domain_sid(self): + sid = "S-1-5-21-3242954042-3778974373-1659123385-1104" + message = SIDPacket() + message.from_string(sid) + expected = b"\x01" \ + b"\x05" \ + b"\x00\x00" \ + b"\x00\x00\x00\x05" \ + b"\x15\x00\x00\x00" \ + b"\x3a\x8d\x4b\xc1" \ + b"\xa5\x92\x3e\xe1" \ + b"\xb9\x36\xe4\x62" \ + b"\x50\x04\x00\x00" + actual = message.pack() + assert len(message) == 28 + assert actual == expected + assert str(message) == sid + + def test_parse_string_fail_no_s(self): + sid = SIDPacket() + with pytest.raises(ValueError) as exc: + sid.from_string("A-1-1-0") + assert str(exc.value) == "A SID string must start with S-" + + def test_parse_string_fail_too_small(self): + sid = SIDPacket() + with pytest.raises(ValueError) as exc: + sid.from_string("S-1") + assert str(exc.value) == "A SID string must start with S and contain" \ + " a revision and identifier authority, e.g." \ + " S-1-0" + + def test_parse_message(self): + actual = SIDPacket() + data = b"\x01" \ + b"\x01" \ + b"\x00\x00" \ + b"\x00\x00\x00\x01" \ + b"\x00\x00\x00\x00" + actual.unpack(data) + assert len(actual) == 12 + assert str(actual) == "S-1-1-0" + assert actual['revision'].get_value() == 1 + assert actual['sub_authority_count'].get_value() == 1 + assert actual['reserved'].get_value() == 0 + assert actual['identifier_authority'].get_value() == 1 + sub_auth = actual['sub_authorities'].get_value() + assert isinstance(sub_auth, list) + assert len(sub_auth) == 1 + assert sub_auth[0] == 0 + + def test_parse_message_domain_sid(self): + actual = SIDPacket() + data = b"\x01" \ + b"\x05" \ + b"\x00\x00" \ + b"\x00\x00\x00\x05" \ + b"\x15\x00\x00\x00" \ + b"\x3a\x8d\x4b\xc1" \ + b"\xa5\x92\x3e\xe1" \ + b"\xb9\x36\xe4\x62" \ + b"\x50\x04\x00\x00" + actual.unpack(data) + assert len(actual) == 28 + assert str(actual) == "S-1-5-21-3242954042-3778974373-1659123385-1104" + assert actual['revision'].get_value() == 1 + assert actual['sub_authority_count'].get_value() == 5 + assert actual['reserved'].get_value() == 0 + assert actual['identifier_authority'].get_value() == 5 + sub_auth = actual['sub_authorities'].get_value() + assert isinstance(sub_auth, list) + assert len(sub_auth) == 5 + assert sub_auth[0] == 21 + assert sub_auth[1] == 3242954042 + assert sub_auth[2] == 3778974373 + assert sub_auth[3] == 1659123385 + assert sub_auth[4] == 1104 + + +class TestAccessAllowedAce(object): + + def test_create_message(self): + sid = SIDPacket() + sid.from_string("S-1-1-0") + + message = AccessAllowedAce() + message['mask'] = 2032127 + message['sid'] = sid + expected = b"\x00" \ + b"\x00" \ + b"\x14\x00" \ + b"\xff\x01\x1f\x00" \ + b"\x01" \ + b"\x01" \ + b"\x00\x00" \ + b"\x00\x00\x00\x01" \ + b"\x00\x00\x00\x00" + actual = message.pack() + assert len(message) == 20 + assert actual == expected + + def test_parse_message(self): + actual = AccessAllowedAce() + data = b"\x00" \ + b"\x00" \ + b"\x14\x00" \ + b"\xff\x01\x1f\x00" \ + b"\x01" \ + b"\x01" \ + b"\x00\x00" \ + b"\x00\x00\x00\x01" \ + b"\x00\x00\x00\x00" + data = actual.unpack(data) + assert len(actual) == 20 + assert data == b"" + assert actual['ace_type'].get_value() == \ + AceType.ACCESS_ALLOWED_ACE_TYPE + assert actual['ace_flags'].get_value() == 0 + assert actual['ace_size'].get_value() == 20 + assert actual['mask'].get_value() == 2032127 + assert str(actual['sid'].get_value()) == "S-1-1-0" + + +class TestAccessDeniedAce(object): + + def test_create_message(self): + sid = SIDPacket() + sid.from_string("S-1-1-0") + + message = AccessDeniedAce() + message['mask'] = 2032127 + message['sid'] = sid + expected = b"\x01" \ + b"\x00" \ + b"\x14\x00" \ + b"\xff\x01\x1f\x00" \ + b"\x01" \ + b"\x01" \ + b"\x00\x00" \ + b"\x00\x00\x00\x01" \ + b"\x00\x00\x00\x00" + actual = message.pack() + assert len(message) == 20 + assert actual == expected + + def test_parse_message(self): + actual = AccessDeniedAce() + data = b"\x01" \ + b"\x00" \ + b"\x14\x00" \ + b"\xff\x01\x1f\x00" \ + b"\x01" \ + b"\x01" \ + b"\x00\x00" \ + b"\x00\x00\x00\x01" \ + b"\x00\x00\x00\x00" + data = actual.unpack(data) + assert len(actual) == 20 + assert data == b"" + assert actual['ace_type'].get_value() == AceType.ACCESS_DENIED_ACE_TYPE + assert actual['ace_flags'].get_value() == 0 + assert actual['ace_size'].get_value() == 20 + assert actual['mask'].get_value() == 2032127 + assert str(actual['sid'].get_value()) == "S-1-1-0" + + +class TestSystemAuditAce(object): + + def test_create_message(self): + sid = SIDPacket() + sid.from_string("S-1-1-0") + + message = SystemAuditAce() + message['mask'] = 2032127 + message['sid'] = sid + expected = b"\x02" \ + b"\x00" \ + b"\x14\x00" \ + b"\xff\x01\x1f\x00" \ + b"\x01" \ + b"\x01" \ + b"\x00\x00" \ + b"\x00\x00\x00\x01" \ + b"\x00\x00\x00\x00" + actual = message.pack() + assert len(message) == 20 + assert actual == expected + + def test_parse_message(self): + actual = SystemAuditAce() + data = b"\x02" \ + b"\x00" \ + b"\x14\x00" \ + b"\xff\x01\x1f\x00" \ + b"\x01" \ + b"\x01" \ + b"\x00\x00" \ + b"\x00\x00\x00\x01" \ + b"\x00\x00\x00\x00" + data = actual.unpack(data) + assert len(actual) == 20 + assert data == b"" + assert actual['ace_type'].get_value() == AceType.SYSTEM_AUDIT_ACE_TYPE + assert actual['ace_flags'].get_value() == 0 + assert actual['ace_size'].get_value() == 20 + assert actual['mask'].get_value() == 2032127 + assert str(actual['sid'].get_value()) == "S-1-1-0" + + +class TestAclPacket(object): + + def test_create_message(self): + sid1 = SIDPacket() + sid1.from_string("S-1-1-0") + sid2 = SIDPacket() + sid2.from_string("S-1-5-21-3242954042-3778974373-1659123385-1104") + + ace1 = AccessAllowedAce() + ace1['mask'] = 2032127 + ace1['sid'] = sid1 + ace2 = AccessAllowedAce() + ace2['mask'] = 2032127 + ace2['sid'] = sid2 + # define an illegal ACE for tests to see if it is flexible for custom + # aces' + ace3 = AccessAllowedAce() + ace3['ace_type'] = AceType.ACCESS_ALLOWED_OBJECT_ACE_TYPE + ace3['sid'] = sid1 + + message = AclPacket() + message['aces'] = [ + ace1, ace2, ace3.pack() + ] + expected = b"\x02" \ + b"\x00" \ + b"\x54\x00" \ + b"\x03\x00" \ + b"\x00\x00" \ + b"\x00" \ + b"\x00" \ + b"\x14\x00" \ + b"\xff\x01\x1f\x00" \ + b"\x01" \ + b"\x01" \ + b"\x00\x00" \ + b"\x00\x00\x00\x01" \ + b"\x00\x00\x00\x00" \ + b"\x00" \ + b"\x00" \ + b"\x24\x00" \ + b"\xff\x01\x1f\x00" \ + b"\x01" \ + b"\x05" \ + b"\x00\x00" \ + b"\x00\x00\x00\x05" \ + b"\x15\x00\x00\x00" \ + b"\x3a\x8d\x4b\xc1" \ + b"\xa5\x92\x3e\xe1" \ + b"\xb9\x36\xe4\x62" \ + b"\x50\x04\x00\x00" \ + b"\x05" \ + b"\x00" \ + b"\x14\x00" \ + b"\x00\x00\x00\x00" \ + b"\x01" \ + b"\x01" \ + b"\x00\x00" \ + b"\x00\x00\x00\x01" \ + b"\x00\x00\x00\x00" + actual = message.pack() + assert len(message) == 84 + assert actual == expected + + def test_parse_message(self): + actual = AclPacket() + data = b"\x02" \ + b"\x00" \ + b"\x54\x00" \ + b"\x03\x00" \ + b"\x00\x00" \ + b"\x00" \ + b"\x00" \ + b"\x14\x00" \ + b"\xff\x01\x1f\x00" \ + b"\x01" \ + b"\x01" \ + b"\x00\x00" \ + b"\x00\x00\x00\x01" \ + b"\x00\x00\x00\x00" \ + b"\x00" \ + b"\x00" \ + b"\x24\x00" \ + b"\xff\x01\x1f\x00" \ + b"\x01" \ + b"\x05" \ + b"\x00\x00" \ + b"\x00\x00\x00\x05" \ + b"\x15\x00\x00\x00" \ + b"\x3a\x8d\x4b\xc1" \ + b"\xa5\x92\x3e\xe1" \ + b"\xb9\x36\xe4\x62" \ + b"\x50\x04\x00\x00" \ + b"\x05" \ + b"\x00" \ + b"\x14\x00" \ + b"\x00\x00\x00\x00" \ + b"\x01" \ + b"\x01" \ + b"\x00\x00" \ + b"\x00\x00\x00\x01" \ + b"\x00\x00\x00\x00" + + actual.unpack(data) + assert len(actual) == 84 + assert actual['acl_revision'].get_value() == AclRevision.ACL_REVISION + assert actual['sbz1'].get_value() == 0 + assert actual['acl_size'].get_value() == 84 + assert actual['ace_count'].get_value() == 3 + assert actual['sbz2'].get_value() == 0 + aces = actual['aces'].get_value() + assert len(aces) == 3 + + assert aces[0]['ace_type'].get_value() == \ + AceType.ACCESS_ALLOWED_ACE_TYPE + assert aces[0]['ace_flags'].get_value() == 0 + assert aces[0]['ace_size'].get_value() == 20 + assert aces[0]['mask'].get_value() == 2032127 + assert str(aces[0]['sid'].get_value()) == "S-1-1-0" + + assert aces[1]['ace_type'].get_value() == \ + AceType.ACCESS_ALLOWED_ACE_TYPE + assert aces[1]['ace_flags'].get_value() == 0 + assert aces[1]['ace_size'].get_value() == 36 + assert aces[1]['mask'].get_value() == 2032127 + assert str(aces[1]['sid'].get_value()) == \ + "S-1-5-21-3242954042-3778974373-1659123385-1104" + + assert isinstance(aces[2], bytes) + assert aces[2] == b"\x05\x00\x14\x00\x00\x00\x00\x00" \ + b"\x01\x01\x00\x00\x00\x00\x00\x01" \ + b"\x00\x00\x00\x00" + + +class TestSMB2SDBuffer(object): + + def test_create_message(self): + sid1 = SIDPacket() + sid1.from_string("S-1-1-0") + sid2 = SIDPacket() + sid2.from_string("S-1-5-21-3242954042-3778974373-1659123385-1104") + + ace1 = AccessAllowedAce() + ace1['mask'] = 2032127 + ace1['sid'] = sid1 + ace2 = AccessAllowedAce() + ace2['mask'] = 2032127 + ace2['sid'] = sid2 + acl = AclPacket() + acl['aces'] = [ + ace1, ace2 + ] + + message = SMB2CreateSDBuffer() + message['control'].set_flag(SDControl.SELF_RELATIVE) + message.set_dacl(acl) + message.set_owner(sid2) + message.set_group(sid1) + message.set_sacl(None) + + expected = b"\x01" \ + b"\x00" \ + b"\x04\x80" \ + b"\x54\x00\x00\x00" \ + b"\x70\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x14\x00\x00\x00" \ + b"\x02" \ + b"\x00" \ + b"\x40\x00" \ + b"\x02\x00" \ + b"\x00\x00" \ + b"\x00" \ + b"\x00" \ + b"\x14\x00" \ + b"\xff\x01\x1f\x00" \ + b"\x01" \ + b"\x01" \ + b"\x00\x00" \ + b"\x00\x00\x00\x01" \ + b"\x00\x00\x00\x00" \ + b"\x00" \ + b"\x00" \ + b"\x24\x00" \ + b"\xff\x01\x1f\x00" \ + b"\x01" \ + b"\x05" \ + b"\x00\x00" \ + b"\x00\x00\x00\x05" \ + b"\x15\x00\x00\x00" \ + b"\x3a\x8d\x4b\xc1" \ + b"\xa5\x92\x3e\xe1" \ + b"\xb9\x36\xe4\x62" \ + b"\x50\x04\x00\x00" \ + b"\x01\x05" \ + b"\x00\x00" \ + b"\x00\x00\x00\x05" \ + b"\x15\x00\x00\x00" \ + b"\x3a\x8d\x4b\xc1" \ + b"\xa5\x92\x3e\xe1" \ + b"\xb9\x36\xe4\x62" \ + b"\x50\x04\x00\x00" \ + b"\x01" \ + b"\x01" \ + b"\x00\x00" \ + b"\x00\x00\x00\x01" \ + b"\x00\x00\x00\x00" + actual = message.pack() + assert len(message) == 124 + assert actual == expected + + def test_create_message_sacl_group(self): + sid = SIDPacket() + sid.from_string("S-1-1-0") + + ace = AccessAllowedAce() + ace['sid'] = sid + acl = AclPacket() + acl['aces'] = [ace] + + message = SMB2CreateSDBuffer() + message.set_dacl(None) + message.set_owner(None) + message.set_group(sid) + message.set_sacl(acl) + + expected = b"\x01" \ + b"\x00" \ + b"\x10\x00" \ + b"\x00\x00\x00\x00" \ + b"\x14\x00\x00\x00" \ + b"\x20\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x01" \ + b"\x01" \ + b"\x00\x00" \ + b"\x00\x00\x00\x01" \ + b"\x00\x00\x00\x00" \ + b"\x02" \ + b"\x00" \ + b"\x1c\x00" \ + b"\x01\x00" \ + b"\x00\x00" \ + b"\x00" \ + b"\x00" \ + b"\x14\x00" \ + b"\x00\x00\x00\x00" \ + b"\x01" \ + b"\x01" \ + b"\x00\x00" \ + b"\x00\x00\x00\x01" \ + b"\x00\x00\x00\x00" + actual = message.pack() + assert len(message) == 60 + assert actual == expected + + def test_parse_message_sacl_group(self): + actual = SMB2CreateSDBuffer() + data = b"\x01" \ + b"\x00" \ + b"\x10\x00" \ + b"\x00\x00\x00\x00" \ + b"\x14\x00\x00\x00" \ + b"\x20\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x01" \ + b"\x01" \ + b"\x00\x00" \ + b"\x00\x00\x00\x01" \ + b"\x00\x00\x00\x00" \ + b"\x02" \ + b"\x00" \ + b"\x1c\x00" \ + b"\x01\x00" \ + b"\x00\x00" \ + b"\x00" \ + b"\x00" \ + b"\x14\x00" \ + b"\x00\x00\x00\x00" \ + b"\x01" \ + b"\x01" \ + b"\x00\x00" \ + b"\x00\x00\x00\x01" \ + b"\x00\x00\x00\x00" + actual.unpack(data) + assert len(actual) == 60 + assert actual['revision'].get_value() == 1 + assert actual['sbz1'].get_value() == 0 + assert actual['control'].get_value() == 16 + assert actual['offset_owner'].get_value() == 0 + assert actual['offset_group'].get_value() == 20 + assert actual['offset_sacl'].get_value() == 32 + assert actual['offset_dacl'].get_value() == 0 + assert len(actual['buffer']) == 40 + + assert not actual.get_owner() + assert str(actual.get_group()) == "S-1-1-0" + sacl = actual.get_sacl() + assert sacl['acl_revision'].get_value() == AclRevision.ACL_REVISION + assert sacl['sbz1'].get_value() == 0 + assert sacl['acl_size'].get_value() == 28 + assert sacl['ace_count'].get_value() == 1 + assert sacl['sbz2'].get_value() == 0 + saces = sacl['aces'].get_value() + assert isinstance(saces, list) + assert len(saces) == 1 + assert saces[0]['ace_type'].get_value() == \ + AceType.ACCESS_ALLOWED_ACE_TYPE + assert saces[0]['ace_flags'].get_value() == 0 + assert saces[0]['ace_size'].get_value() == 20 + assert saces[0]['mask'].get_value() == 0 + assert str(saces[0]['sid']) == "S-1-1-0" + + assert not actual.get_dacl() + + def test_parse_message(self): + actual = SMB2CreateSDBuffer() + data = b"\x01" \ + b"\x00" \ + b"\x04\x80" \ + b"\x54\x00\x00\x00" \ + b"\x70\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x14\x00\x00\x00" \ + b"\x02" \ + b"\x00" \ + b"\x40\x00" \ + b"\x02\x00" \ + b"\x00\x00" \ + b"\x00" \ + b"\x00" \ + b"\x14\x00" \ + b"\xff\x01\x1f\x00" \ + b"\x01" \ + b"\x01" \ + b"\x00\x00" \ + b"\x00\x00\x00\x01" \ + b"\x00\x00\x00\x00" \ + b"\x00" \ + b"\x00" \ + b"\x24\x00" \ + b"\xff\x01\x1f\x00" \ + b"\x01" \ + b"\x05" \ + b"\x00\x00" \ + b"\x00\x00\x00\x05" \ + b"\x15\x00\x00\x00" \ + b"\x3a\x8d\x4b\xc1" \ + b"\xa5\x92\x3e\xe1" \ + b"\xb9\x36\xe4\x62" \ + b"\x50\x04\x00\x00" \ + b"\x01\x05" \ + b"\x00\x00" \ + b"\x00\x00\x00\x05" \ + b"\x15\x00\x00\x00" \ + b"\x3a\x8d\x4b\xc1" \ + b"\xa5\x92\x3e\xe1" \ + b"\xb9\x36\xe4\x62" \ + b"\x50\x04\x00\x00" \ + b"\x01" \ + b"\x01" \ + b"\x00\x00" \ + b"\x00\x00\x00\x01" \ + b"\x00\x00\x00\x00" + actual.unpack(data) + assert len(actual) == 124 + assert actual['revision'].get_value() == 1 + assert actual['sbz1'].get_value() == 0 + assert actual['control'].get_value() == 32772 + assert actual['offset_owner'].get_value() == 84 + assert actual['offset_group'].get_value() == 112 + assert actual['offset_sacl'].get_value() == 0 + assert actual['offset_dacl'].get_value() == 20 + assert len(actual['buffer']) == 104 + + assert str(actual.get_owner()) == \ + "S-1-5-21-3242954042-3778974373-1659123385-1104" + assert str(actual.get_group()) == "S-1-1-0" + assert not actual.get_sacl() + dacl = actual.get_dacl() + assert dacl['acl_revision'].get_value() == AclRevision.ACL_REVISION + assert dacl['sbz1'].get_value() == 0 + assert dacl['acl_size'].get_value() == 64 + assert dacl['ace_count'].get_value() == 2 + assert dacl['sbz2'].get_value() == 0 + daces = dacl['aces'].get_value() + assert isinstance(daces, list) + assert len(daces) == 2 + assert daces[0]['ace_type'].get_value() == \ + AceType.ACCESS_ALLOWED_ACE_TYPE + assert daces[0]['ace_flags'].get_value() == 0 + assert daces[0]['ace_size'].get_value() == 20 + assert daces[0]['mask'].get_value() == 2032127 + assert str(daces[0]['sid']) == "S-1-1-0" + assert daces[1]['ace_type'].get_value() == \ + AceType.ACCESS_ALLOWED_ACE_TYPE + assert daces[1]['ace_flags'].get_value() == 0 + assert daces[1]['ace_size'].get_value() == 36 + assert daces[1]['mask'].get_value() == 2032127 + assert str(daces[1]['sid']) == \ + "S-1-5-21-3242954042-3778974373-1659123385-1104" diff --git a/tests/test_session.py b/tests/test_session.py new file mode 100644 index 00000000..692f5fcf --- /dev/null +++ b/tests/test_session.py @@ -0,0 +1,302 @@ +import uuid + +import pytest + +from smbprotocol.connection import Connection, Dialects, SecurityMode +from smbprotocol.exceptions import SMBAuthenticationError, SMBException +from smbprotocol.session import Session, SMB2Logoff, SMB2SessionSetupRequest, \ + SMB2SessionSetupResponse + +from .utils import smb_real + + +class TestSMB2SessionSetupRequest(object): + + def test_create_message(self): + message = SMB2SessionSetupRequest() + message['security_mode'] = SecurityMode.SMB2_NEGOTIATE_SIGNING_ENABLED + message['buffer'] = b"\x01\x02\x03\x04" + expected = b"\x19\x00" \ + b"\x00" \ + b"\x01" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x58\x00" \ + b"\x04\x00" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\x01\x02\x03\x04" + actual = message.pack() + assert len(message) == 28 + assert actual == expected + + def test_parse_message(self): + actual = SMB2SessionSetupRequest() + data = b"\x19\x00" \ + b"\x00" \ + b"\x01" \ + b"\x00\x00\x00\x00" \ + b"\x00\x00\x00\x00" \ + b"\x58\x00" \ + b"\x04\x00" \ + b"\x00\x00\x00\x00\x00\x00\x00\x00" \ + b"\x01\x02\x03\x04" + actual.unpack(data) + assert len(actual) == 28 + assert actual['structure_size'].get_value() == 25 + assert actual['flags'].get_value() == 0 + assert actual['security_mode'].get_value() == 1 + assert actual['capabilities'].get_value() == 0 + assert actual['security_buffer_offset'].get_value() == 88 + assert actual['security_buffer_length'].get_value() == 4 + assert actual['previous_session_id'].get_value() == 0 + assert actual['buffer'].get_value() == b"\x01\x02\x03\x04" + + +class TestSMB2SessionSetupResponse(object): + + def test_create_message(self): + message = SMB2SessionSetupResponse() + message['session_flags'] = 1 + message['buffer'] = b"\x04\x03\x02\x01" + expected = b"\x09\x00" \ + b"\x01\x00" \ + b"\x48\x00" \ + b"\x04\x00" \ + b"\x04\x03\x02\x01" + actual = message.pack() + assert len(message) == 12 + assert actual == expected + + def test_parse_message(self): + actual = SMB2SessionSetupResponse() + data = b"\x09\x00" \ + b"\x01\x00" \ + b"\x48\x00" \ + b"\x04\x00" \ + b"\x04\x03\x02\x01" + actual.unpack(data) + assert len(actual) == 12 + assert actual['structure_size'].get_value() == 9 + assert actual['session_flags'].get_value() == 1 + assert actual['security_buffer_offset'].get_value() == 72 + assert actual['security_buffer_length'].get_value() == 4 + assert actual['buffer'].get_value() == b"\x04\x03\x02\x01" + + +class TestSMB2Logoff(object): + + def test_create_message(self): + message = SMB2Logoff() + expected = b"\x04\x00" \ + b"\x00\x00" + actual = message.pack() + assert len(message) == 4 + assert actual == expected + + def test_parse_message(self): + actual = SMB2Logoff() + data = b"\x04\x00" \ + b"\x00\x00" + actual.unpack(data) + assert len(actual) == 4 + assert actual['structure_size'].get_value() == 4 + assert actual['reserved'].get_value() == 0 + + +class TestSession(object): + + def test_dialect_2_0_2(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3]) + connection.connect(Dialects.SMB_2_0_2) + session = Session(connection, smb_real[0], smb_real[1], + require_encryption=False) + try: + session.connect() + assert len(session.application_key) == 16 + assert session.decryption_key is None + assert not session.encrypt_data + assert session.encryption_key is None + assert len(session.preauth_integrity_hash_value) == 5 + assert not session.require_encryption + assert session.session_id is not None + assert session.session_key == session.application_key + assert session.signing_key == session.signing_key + assert session.signing_required + finally: + connection.disconnect(True) + + def test_dialect_2_1_0(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3]) + connection.connect(Dialects.SMB_2_1_0) + session = Session(connection, smb_real[0], smb_real[1], + require_encryption=False) + try: + session.connect() + assert len(session.application_key) == 16 + assert session.decryption_key is None + assert not session.encrypt_data + assert session.encryption_key is None + assert len(session.preauth_integrity_hash_value) == 5 + assert not session.require_encryption + assert session.session_id is not None + assert session.session_key == session.application_key + assert session.signing_key == session.signing_key + assert session.signing_required + finally: + connection.disconnect(True) + + def test_dialect_3_0_0(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3]) + connection.connect(Dialects.SMB_3_0_0) + session = Session(connection, smb_real[0], smb_real[1]) + try: + session.connect() + assert len(session.application_key) == 16 + assert session.application_key != session.session_key + assert len(session.decryption_key) == 16 + assert session.decryption_key != session.session_key + assert session.encrypt_data + assert len(session.encryption_key) == 16 + assert session.encryption_key != session.session_key + assert len(session.preauth_integrity_hash_value) == 5 + assert session.require_encryption + assert session.session_id is not None + assert len(session.session_key) == 16 + assert len(session.signing_key) == 16 + assert session.signing_key != session.session_key + assert not session.signing_required + finally: + connection.disconnect(True) + + def test_dialect_3_0_2(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3]) + connection.connect(Dialects.SMB_3_0_2) + session = Session(connection, smb_real[0], smb_real[1]) + try: + session.connect() + assert len(session.application_key) == 16 + assert session.application_key != session.session_key + assert len(session.decryption_key) == 16 + assert session.decryption_key != session.session_key + assert session.encrypt_data + assert len(session.encryption_key) == 16 + assert session.encryption_key != session.session_key + assert len(session.preauth_integrity_hash_value) == 5 + assert session.require_encryption + assert session.session_id is not None + assert len(session.session_key) == 16 + assert len(session.signing_key) == 16 + assert session.signing_key != session.session_key + assert not session.signing_required + finally: + connection.disconnect(True) + + def test_dialect_3_1_1(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3]) + connection.connect(Dialects.SMB_3_1_1) + session = Session(connection, smb_real[0], smb_real[1]) + try: + session.connect() + assert len(session.application_key) == 16 + assert session.application_key != session.session_key + assert len(session.decryption_key) == 16 + assert session.decryption_key != session.session_key + assert session.encrypt_data + assert len(session.encryption_key) == 16 + assert session.encryption_key != session.session_key + assert len(session.preauth_integrity_hash_value) == 5 + assert session.require_encryption + assert session.session_id is not None + assert len(session.session_key) == 16 + assert len(session.signing_key) == 16 + assert session.signing_key != session.session_key + assert not session.signing_required + finally: + connection.disconnect(True) + # test that disconnect can be run mutliple times + session.disconnect() + + def test_require_encryption(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3]) + connection.connect() + session = Session(connection, smb_real[0], smb_real[1], True) + try: + session.connect() + assert len(session.application_key) == 16 + assert session.application_key != session.session_key + assert len(session.decryption_key) == 16 + assert session.decryption_key != session.session_key + assert session.encrypt_data + assert len(session.encryption_key) == 16 + assert session.encryption_key != session.session_key + assert len(session.preauth_integrity_hash_value) == 5 + assert session.require_encryption + assert session.session_id is not None + assert len(session.session_key) == 16 + assert len(session.signing_key) == 16 + assert session.signing_key != session.session_key + assert not session.signing_required + finally: + connection.disconnect(True) + + def test_require_encryption_not_supported(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3]) + connection.connect(Dialects.SMB_2_1_0) + session = None + try: + session = Session(connection, smb_real[0], smb_real[1]) + with pytest.raises(SMBException) as exc: + session.connect() + assert str(exc.value) == "SMB encryption is required but the " \ + "connection does not support it" + finally: + connection.disconnect(True) + + def test_setup_session_with_ms_gss_token(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3]) + connection.connect() + connection.gss_negotiate_token = b"\x60\x76\x06\x06\x2b\x06\x01\x05" \ + b"\x05\x02\xa0\x6c\x30\x6a\xa0\x3c" \ + b"\x30\x3a\x06\x0a\x2b\x06\x01\x04" \ + b"\x01\x82\x37\x02\x02\x1e\x06\x09" \ + b"\x2a\x86\x48\x82\xf7\x12\x01\x02" \ + b"\x02\x06\x09\x2a\x86\x48\x86\xf7" \ + b"\x12\x01\x02\x02\x06\x0a\x2a\x86" \ + b"\x48\x86\xf7\x12\x01\x02\x02\x03" \ + b"\x06\x0a\x2b\x06\x01\x04\x01\x82" \ + b"\x37\x02\x02\x0a\xa3\x2a\x30\x28" \ + b"\xa0\x26\x1b\x24\x6e\x6f\x74\x5f" \ + b"\x64\x65\x66\x69\x6e\x65\x64\x5f" \ + b"\x69\x6e\x5f\x52\x46\x43\x34\x31" \ + b"\x37\x38\x40\x70\x6c\x65\x61\x73" \ + b"\x65\x5f\x69\x67\x6e\x6f\x72\x65" + session = Session(connection, smb_real[0], smb_real[1], False) + try: + session.connect() + assert len(session.application_key) == 16 + assert session.application_key != session.session_key + assert len(session.decryption_key) == 16 + assert session.decryption_key != session.session_key + assert not session.encrypt_data + assert len(session.encryption_key) == 16 + assert session.encryption_key != session.session_key + assert len(session.preauth_integrity_hash_value) == 5 + assert not session.require_encryption + assert session.session_id is not None + assert len(session.session_key) == 16 + assert len(session.signing_key) == 16 + assert session.signing_key != session.session_key + assert session.signing_required + finally: + connection.disconnect(True) + + def test_invalid_user(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3]) + connection.connect() + try: + session = Session(connection, "fakeuser", "fakepass") + with pytest.raises(SMBAuthenticationError) as exc: + session.connect() + assert "Failed to authenticate with server: " in str(exc.value) + finally: + connection.disconnect(True) diff --git a/tests/test_spnego.py b/tests/test_spnego.py new file mode 100644 index 00000000..8b4edc00 --- /dev/null +++ b/tests/test_spnego.py @@ -0,0 +1,39 @@ +from pyasn1.codec.der.decoder import decode +from pyasn1.type.univ import ObjectIdentifier + +from smbprotocol.spnego import InitialContextToken, NegotiateToken, MechTypes + + +class TestSpnego(object): + + def test_parse_initial_context_token(self): + data = b"\x60\x76\x06\x06\x2b\x06\x01\x05" \ + b"\x05\x02\xa0\x6c\x30\x6a\xa0\x3c" \ + b"\x30\x3a\x06\x0a\x2b\x06\x01\x04" \ + b"\x01\x82\x37\x02\x02\x1e\x06\x09" \ + b"\x2a\x86\x48\x82\xf7\x12\x01\x02" \ + b"\x02\x06\x09\x2a\x86\x48\x86\xf7" \ + b"\x12\x01\x02\x02\x06\x0a\x2a\x86" \ + b"\x48\x86\xf7\x12\x01\x02\x02\x03" \ + b"\x06\x0a\x2b\x06\x01\x04\x01\x82" \ + b"\x37\x02\x02\x0a\xa3\x2a\x30\x28" \ + b"\xa0\x26\x1b\x24\x6e\x6f\x74\x5f" \ + b"\x64\x65\x66\x69\x6e\x65\x64\x5f" \ + b"\x69\x6e\x5f\x52\x46\x43\x34\x31" \ + b"\x37\x38\x40\x70\x6c\x65\x61\x73" \ + b"\x65\x5f\x69\x67\x6e\x6f\x72\x65" + actual, rdata = decode(data, asn1Spec=InitialContextToken()) + assert rdata == b"" + assert actual['thisMech'] == ObjectIdentifier('1.3.6.1.5.5.2') + assert isinstance(actual['innerContextToken'], NegotiateToken) + actual_token = actual['innerContextToken']['negTokenInit'] + assert actual_token['mechTypes'] == [ + MechTypes.NEGOEX, + MechTypes.MS_KRB5, + MechTypes.KRB5, + MechTypes.KRB5_U2U, + MechTypes.NTLMSSP + + ] + assert actual_token['negHints']['hintName'] == \ + "not_defined_in_RFC4178@please_ignore" diff --git a/tests/test_structure.py b/tests/test_structure.py new file mode 100644 index 00000000..803000dd --- /dev/null +++ b/tests/test_structure.py @@ -0,0 +1,1449 @@ +import pytest +import types +import uuid + +from datetime import datetime + +try: + from collections import OrderedDict +except ImportError: + from ordereddict import OrderedDict + +from smbprotocol.connection import Capabilities, Commands, Dialects +from smbprotocol.structure import Structure, IntField, BytesField, ListField, \ + UuidField, DateTimeField, StructureField, EnumField, FlagField, \ + BoolField, _bytes_to_hex, InvalidFieldDefinition + + +def test_bytes_to_hex_pretty_newline(): + bytes_str = b"\x00\x01abc123new" + expected = "00 01 61 62 63 31 32 33\n6E 65 77" + actual = _bytes_to_hex(bytes_str, pretty=True) + assert actual == expected + + +def test_bytes_to_hex_pretty_newline_override(): + bytes_str = b"\x00\x01abc123new" + expected = "00 01 61 62\n63 31 32 33\n6E 65 77" + actual = _bytes_to_hex(bytes_str, pretty=True, hex_per_line=4) + assert actual == expected + + +def test_bytes_to_hex_pretty_nonewline(): + bytes_str = b"\x00\x01abc123new" + expected = "00 01 61 62 63 31 32 33 6E 65 77" + actual = _bytes_to_hex(bytes_str, pretty=True, hex_per_line=0) + assert actual == expected + + +def test_bytes_to_hex_not_pretty(): + bytes_str = b"\x00\x01abc123new" + expected = "00016162633132336e6577" + actual = _bytes_to_hex(bytes_str, pretty=False) + assert actual == expected + + +class Structure2(Structure): + def __init__(self): + self.fields = OrderedDict([ + ('field', IntField( + size=4, + default=125, + )), + ('bytes', BytesField( + size=4, + default=b"\x10\x11\x12\x13", + )), + ]) + super(Structure2, self).__init__() + + +class Structure1(Structure): + def __init__(self): + self.fields = OrderedDict([ + ('int_field', IntField(size=4)), + ('bytes_field', BytesField(size=2)), + ('var_field', BytesField( + size=lambda s: s['int_field'].get_value(), + )), + ('default_field', IntField( + size=2, + default=b"\x01a", + )), + ('list_field', ListField( + list_count=lambda s: s['int_field'].get_value(), + list_type=BytesField(size=8), + size=lambda s: s['int_field'].get_value() * 8, + )), + ('structure_length', IntField( + size=2, + little_endian=False, + default=lambda s: len(s['structure_field']), + )), + ('structure_field', StructureField( + size=lambda s: s['structure_length'].get_value(), + structure_type=Structure2, + )), + ]) + + super(Structure1, self).__init__() + + +class TestStructure(object): + + def test_structure_defaults(self): + actual = Structure1() + assert len(actual.fields) == 7 + assert actual['int_field'].get_value() == 0 + assert actual['bytes_field'].get_value() == b"" + assert actual['var_field'].get_value() == b"" + assert actual['default_field'].get_value() == 24833 + assert actual['list_field'].get_value() == [] + assert actual['structure_length'].get_value() == 0 + assert actual['structure_field'].get_value() == b"" + + def test_get_field(self): + structure = Structure1() + actual = structure['default_field'] + assert actual.name == "default_field" + assert actual.size == 2 + assert actual.get_value() == 24833 + + def test_set_field(self): + structure = Structure1() + assert structure['int_field'].get_value() == 0 + structure['int_field'] = 10 + assert structure['int_field'].get_value() == 10 + + def test_remove_field(self): + structure = Structure1() + assert len(structure.fields) == 7 + del structure['int_field'] + assert len(structure.fields) == 6 + with pytest.raises(ValueError) as exc: + value = structure['int_field'] + assert str(exc.value) == "Structure does not contain field int_field" + + def test_pack_structure(self): + structure = Structure1() + sub_structure = Structure2() + structure['int_field'] = 3 + structure['bytes_field'] = b"\x01\x02" + structure['var_field'] = b"\x03\x04\x05" + structure['list_field'] = [ + b"\x31\x00\x32\x00\x33\x00\x34\x00", + b"1\x002\x003\x004\00", + sub_structure, + ] + structure['structure_field'] = sub_structure + + expected = b"\x03\x00\x00\x00" \ + b"\x01\x02" \ + b"\x03\x04\x05" \ + b"\x01\x61" \ + b"\x31\x00\x32\x00\x33\x00\x34\x00" \ + b"\x31\x00\x32\x00\x33\x00\x34\x00" \ + b"\x7d\x00\x00\x00\x10\x11\x12\x13" \ + b"\x00\x08" \ + b"\x7d\x00\x00\x00\x10\x11\x12\x13" + actual = structure.pack() + assert actual == expected + assert len(structure) == len(actual) + + def test_unpack_structure(self): + packed_data = b"\x03\x00\x00\x00" \ + b"\x01\x02" \ + b"\x03\x04\x05" \ + b"\x01\x61" \ + b"\x31\x00\x32\x00\x33\x00\x34\x00" \ + b"\x31\x00\x32\x00\x33\x00\x34\x00" \ + b"\x7d\x00\x00\x00\x10\x11\x12\x13" \ + b"\x00\x08" \ + b"\x7d\x00\x00\x00\x10\x11\x12\x13" + + actual = Structure1() + actual.unpack(packed_data) + assert actual['int_field'].get_value() == 3 + assert actual['bytes_field'].get_value() == b"\x01\x02" + assert actual['var_field'].get_value() == b"\x03\x04\x05" + assert actual['default_field'].get_value() == 24833 + assert actual['list_field'].get_value() == [ + b"\x31\x00\x32\x00\x33\x00\x34\x00", + b"\x31\x00\x32\x00\x33\x00\x34\x00", + b"\x7d\x00\x00\x00\x10\x11\x12\x13" + ] + assert actual['structure_length'].get_value() == 8 + expected_struct = Structure2().pack() + assert actual['structure_field'].get_value().pack() == expected_struct + assert len(actual) == len(packed_data) + + def test_structure_string(self): + structure = Structure1() + sub_structure = Structure2() + structure['int_field'] = 3 + structure['bytes_field'] = b"\x01\x02" + structure['var_field'] = b"\x03\x04\x05" + structure['list_field'] = [ + b"\x31\x00\x32\x00\x33\x00\x34\x00", + b"1\x002\x003\x004\x00", + sub_structure, + ] + structure['structure_field'] = sub_structure + + expected = """Structure1: + int_field = 3 + bytes_field = 01 02 + var_field = 03 04 05 + default_field = 24833 + list_field = [ + 31 00 32 00 33 00 34 00, + 31 00 32 00 33 00 34 00, + Structure2: + field = 125 + bytes = 10 11 12 13 + + Raw Hex: + 7D 00 00 00 10 11 12 13 + ] + structure_length = 8 + structure_field = + Structure2: + field = 125 + bytes = 10 11 12 13 + + Raw Hex: + 7D 00 00 00 10 11 12 13 + + Raw Hex: + 03 00 00 00 01 02 03 04 + 05 01 61 31 00 32 00 33 + 00 34 00 31 00 32 00 33 + 00 34 00 7D 00 00 00 10 + 11 12 13 00 08 7D 00 00 + 00 10 11 12 13""" + actual = str(structure) + assert actual == expected + + def test_end_field_no_size(self): + class Structure3(Structure): + def __init__(self): + self.fields = OrderedDict([ + ('field', IntField(size=2, default=1)), + ('end', BytesField()), + ]) + super(Structure3, self).__init__() + + structure = Structure3() + structure['end'] = b"\x01\x02\x03\x04" + expected_pack = b"\x01\x00\x01\x02\x03\x04" + actual_pack = structure.pack() + assert actual_pack == expected_pack + assert len(structure['end']) == 4 + + structure.unpack(b"\x02\x00\x05\x06\x07\x08\x09\x10") + assert structure['field'].get_value() == 2 + assert structure['end'].get_value() == b"\x05\x06\x07\x08\x09\x10" + assert len(structure['end']) == 6 + + +class TestIntField(object): + + class StructureTest(Structure): + def __init__(self): + self.fields = OrderedDict([ + ('field', IntField(size=4, default=1234)) + ]) + super(TestIntField.StructureTest, self).__init__() + + def test_get_size(self): + field = self.StructureTest()['field'] + expected = 4 + actual = len(field) + assert actual == expected + + def test_to_string(self): + field = self.StructureTest()['field'] + expected = "1234" + actual = str(field) + assert actual == expected + + def test_get_value(self): + field = self.StructureTest()['field'] + expected = 1234 + actual = field.get_value() + assert actual == expected + + def test_pack(self): + field = self.StructureTest()['field'] + expected = b"\xd2\x04\x00\x00" + actual = field.pack() + assert actual == expected + + def test_pack_with_lambda_size(self): + structure = self.StructureTest() + field = structure['field'] + field.name = "field" + field.structure = self.StructureTest + field.size = lambda s: 2 + field.set_value(4) + expected = b"\x04\x00" + actual = field.pack() + assert actual == expected + + def test_unpack(self): + field = self.StructureTest()['field'] + field.unpack(b"\xd2\x05\x00\x00") + expected = 1490 + actual = field.get_value() + assert actual == expected + + def test_invalid_size_none(self): + with pytest.raises(InvalidFieldDefinition) as exc: + IntField(size=None) + assert str(exc.value) == "IntField size must have a value of 1, 2, " \ + "4, or 8 not None" + + def test_invalid_size_bad_int(self): + with pytest.raises(InvalidFieldDefinition) as exc: + IntField(size=3) + assert str(exc.value) == "IntField size must have a value of 1, 2, " \ + "4, or 8 not 3" + + def test_set_none(self): + field = self.StructureTest()['field'] + field.set_value(None) + expected = 0 + actual = field.get_value() + assert isinstance(field.value, int) + assert actual == expected + + def test_set_lambda(self): + structure = self.StructureTest() + field = structure['field'] + field.name = "field" + field.structure = self.StructureTest + field.set_value(lambda s: 4567) + expected = 4567 + actual = field.get_value() + assert isinstance(field.value, types.LambdaType) + assert actual == expected + assert len(field) == 4 + + def test_set_bytes(self): + field = self.StructureTest()['field'] + field.set_value(b"\x12\x34\x00\x00") + expected = 13330 + actual = field.get_value() + assert isinstance(field.value, int) + assert actual == expected + + def test_set_int(self): + field = self.StructureTest()['field'] + field.set_value(9876) + expected = 9876 + actual = field.get_value() + assert isinstance(field.value, int) + assert actual == expected + + def test_set_invalid(self): + field = self.StructureTest()['field'] + field.name = "field" + with pytest.raises(TypeError) as exc: + field.set_value([]) + assert str(exc.value) == "Cannot parse value for field field of " \ + "type list to an int" + + def test_byte_order(self): + class ByteOrderStructure(Structure): + def __init__(self): + self.fields = OrderedDict([( + 'field', IntField(size=2, little_endian=False, default=10) + )]) + super(ByteOrderStructure, self).__init__() + + field = ByteOrderStructure()['field'] + expected = b"\x00\x0a" + actual = field.pack() + assert actual == expected + + +class TestBytesField(object): + + class StructureTest(Structure): + def __init__(self): + self.fields = OrderedDict([ + ('field', BytesField(size=4, default=b"\x10\x11\x12\x13")) + ]) + super(TestBytesField.StructureTest, self).__init__() + + def test_get_size(self): + field = self.StructureTest()['field'] + expected = 4 + actual = len(field) + assert actual == expected + + def test_to_string(self): + field = self.StructureTest()['field'] + expected = "10 11 12 13" + actual = str(field) + assert actual == expected + + def test_get_value(self): + field = self.StructureTest()['field'] + expected = b"\x10\x11\x12\x13" + actual = field.get_value() + assert actual == expected + + def test_pack(self): + field = self.StructureTest()['field'] + expected = b"\x10\x11\x12\x13" + actual = field.pack() + assert actual == expected + + def test_unpack(self): + field = self.StructureTest()['field'] + field.unpack(b"\x7a\x00\x79\x00") + expected = b"\x7a\x00\x79\x00" + actual = field.get_value() + assert actual == expected + + def test_set_none(self): + field = self.StructureTest()['field'] + field.set_value(None) + expected = b"" + actual = field.get_value() + assert isinstance(field.value, bytes) + assert actual == expected + + def test_set_lambda(self): + structure = self.StructureTest() + field = structure['field'] + field.name = "field" + field.structure = self.StructureTest + field.set_value(lambda s: b"\x10\x11\x12\x13") + expected = b"\x10\x11\x12\x13" + actual = field.get_value() + assert isinstance(field.value, types.LambdaType) + assert actual == expected + assert len(field) == 4 + + def test_set_bytes(self): + field = self.StructureTest()['field'] + field.set_value(b"\x78\x00\x77\x00") + expected = b"\x78\x00\x77\x00" + actual = field.get_value() + assert isinstance(field.value, bytes) + assert actual == expected + + def test_set_int(self): + field = self.StructureTest()['field'] + field.set_value(11) + expected = b"\x0b\x00\x00\x00" + actual = field.get_value() + assert isinstance(field.value, bytes) + assert actual == expected + + def test_set_structure(self): + field = self.StructureTest()['field'] + field.size = 8 + field.set_value(Structure2()) + expected = b"\x7d\x00\x00\x00\x10\x11\x12\x13" + actual = field.get_value() + assert isinstance(field.value, bytes) + assert actual == expected + assert len(field) == 8 + + def test_set_invalid(self): + field = self.StructureTest()['field'] + field.name = "field" + with pytest.raises(TypeError) as exc: + field.set_value([]) + assert str(exc.value) == "Cannot parse value for field field of " \ + "type list to a byte string" + + def test_pack_invalid_size(self): + field = self.StructureTest()['field'] + field.name = "field" + field.set_value(b"\x01\x02") + assert len(field) == 2 + with pytest.raises(ValueError) as exc: + field.pack() + assert str(exc.value) == "Invalid packed data length for field " \ + "field of 2 does not fit field size of 4" + + def test_set_int_invalid_size(self): + class InvalidSizeStructure(Structure): + def __init__(self): + self.fields = OrderedDict([( + 'field', BytesField(size=3) + )]) + super(InvalidSizeStructure, self).__init__() + + with pytest.raises(InvalidFieldDefinition) as exc: + field = InvalidSizeStructure()['field'] + field.set_value(1) + assert str(exc.value) == "Cannot struct format of size 3" + + def test_set_invalid_size(self): + class InvalidSizeStructure(Structure): + def __init__(self): + self.fields = OrderedDict([( + 'field', BytesField(size="a") + )]) + super(InvalidSizeStructure, self).__init__() + + with pytest.raises(InvalidFieldDefinition) as exc: + InvalidSizeStructure() + assert str(exc.value) == "BytesField size for field must be an int " \ + "or None for a variable length" + + +class TestListField(object): + # unpack variable length list + + class StructureTest(Structure): + def __init__(self): + self.fields = OrderedDict([ + ('field', ListField( + size=4, + list_count=2, + list_type=BytesField(size=2), + default=[b"\x01\x02", b"\x03\x04"] + )) + ]) + super(TestListField.StructureTest, self).__init__() + + def test_get_size(self): + field = self.StructureTest()['field'] + expected = 4 + actual = len(field) + assert actual == expected + + def test_get_item(self): + field = self.StructureTest()['field'] + assert field[0] == b"\x01\x02" + assert field[1] == b"\x03\x04" + + def test_to_string(self): + field = self.StructureTest()['field'] + expected = "[\n 01 02,\n 03 04\n]" + actual = str(field) + assert actual == expected + + def test_to_string_empty(self): + field = self.StructureTest()['field'] + field.set_value([]) + expected = "[]" + actual = str(field) + assert actual == expected + + def test_get_value(self): + field = self.StructureTest()['field'] + expected = [b"\x01\x02", b"\x03\x04"] + actual = field.get_value() + assert actual == expected + + def test_pack(self): + field = self.StructureTest()['field'] + expected = b"\x01\x02\x03\x04" + actual = field.pack() + assert actual == expected + + def test_unpack(self): + field = self.StructureTest()['field'] + data = field.unpack(b"\x7a\x00\x79\x00") + expected = [b"\x7a\x00", b"\x79\x00"] + actual = field.get_value() + assert actual == expected + + def test_unpack_func(self): + class UnpackListStructure(Structure): + def __init__(self): + self.fields = OrderedDict([ + ('field', ListField( + size=7, + unpack_func=lambda s, d: [ + b"\x01\x02", + b"\x03\x04\x05\x06", + b"\07" + ] + )) + + ]) + super(UnpackListStructure, self).__init__() + + field = UnpackListStructure()['field'] + field.unpack(b"\x00") + expected = [ + b"\x01\x02", + b"\x03\x04\x05\x06", + b"\07" + ] + actual = field.get_value() + assert len(field) == 7 + assert actual == expected + + def test_set_none(self): + field = self.StructureTest()['field'] + field.set_value(None) + expected = [] + actual = field.get_value() + assert isinstance(field.value, list) + assert actual == expected + assert len(field) == 0 + assert len(field.get_value()) == 0 + + def test_set_lambda_as_bytes(self): + structure = self.StructureTest() + field = structure['field'] + field.name = "field" + field.structure = self.StructureTest + field.set_value(lambda s: b"\x10\x11\x12\x13") + expected = [b"\x10\x11", b"\x12\x13"] + actual = field.get_value() + assert isinstance(field.value, types.LambdaType) + assert actual == expected + assert len(field) == 4 + + def test_set_lambda_as_list(self): + structure = self.StructureTest() + field = structure['field'] + field.name = "field" + field.structure = self.StructureTest + field.set_value(lambda s: [b"\x10\x11", b"\x12\x13"]) + expected = [b"\x10\x11", b"\x12\x13"] + actual = field.get_value() + assert isinstance(field.value, types.LambdaType) + assert actual == expected + assert len(field) == 4 + + def test_set_bytes_fixed(self): + field = self.StructureTest()['field'] + field.set_value(b"\x78\x00\x77\x00") + expected = [b"\x78\x00", b"\x77\x00"] + actual = field.get_value() + assert isinstance(field.value, list) + assert actual == expected + + def test_set_list(self): + field = self.StructureTest()['field'] + field.set_value([b"\x7d\x00", b"\x00\x00"]) + expected = [b"\x7d\x00", b"\x00\x00"] + actual = field.get_value() + assert isinstance(field.value, list) + assert actual == expected + assert len(field) == 4 + assert len(field.get_value()) == 2 + + def test_set_invalid(self): + field = self.StructureTest()['field'] + field.name = "field" + with pytest.raises(TypeError) as exc: + field.set_value(0) + assert str(exc.value) == "Cannot parse value for field field of " \ + "type int to a list" + + def test_list_count_not_int_or_lambda(self): + class InvalidListField(Structure): + def __init__(self): + self.fields = OrderedDict([ + ('field', ListField(list_count="a")) + ]) + super(InvalidListField, self).__init__() + with pytest.raises(InvalidFieldDefinition) as exc: + InvalidListField() + assert str(exc.value) == "ListField list_count must be an int, " \ + "lambda, or None for a variable list length" + + def test_unpack_func_not_lambda(self): + class InvalidListField(Structure): + def __init__(self): + self.fields = OrderedDict([ + ('field', ListField(unpack_func="a")) + ]) + super(InvalidListField, self).__init__() + with pytest.raises(InvalidFieldDefinition) as exc: + InvalidListField() + assert str(exc.value) == "ListField unpack_func must be a lambda " \ + "function or None" + + def test_list_field_not_field(self): + class InvalidListField(Structure): + def __init__(self): + self.fields = OrderedDict([ + ('field', ListField(list_type="a")) + ]) + super(InvalidListField, self).__init__() + with pytest.raises(InvalidFieldDefinition) as exc: + InvalidListField() + assert str(exc.value) == "ListField list_type must be a Field " \ + "definition" + + def test_list_unpack_list_type_size_not_defined(self): + class InvalidListField(Structure): + def __init__(self): + self.fields = OrderedDict([ + ('field', ListField(list_count=1)) + ]) + super(InvalidListField, self).__init__() + with pytest.raises(InvalidFieldDefinition) as exc: + InvalidListField() + assert str(exc.value) == "ListField must either define unpack_func " \ + "as a lambda or set list_count and " \ + "list_size with a size" + + def test_list_unpack_list_count_not_defined(self): + class InvalidListField(Structure): + def __init__(self): + self.fields = OrderedDict([ + ('field', ListField(list_type=BytesField(size=1))) + ]) + super(InvalidListField, self).__init__() + with pytest.raises(InvalidFieldDefinition) as exc: + InvalidListField() + assert str(exc.value) == "ListField must either define unpack_func " \ + "as a lambda or set list_count and " \ + "list_size with a size" + + +class TestStructureField(object): + + class StructureTest(Structure): + def __init__(self): + self.fields = OrderedDict([ + ('field', StructureField( + size=8, + structure_type=Structure2, + default=b"\x7d\x00\x00\x00\x10\x11\x12\x13" + )) + ]) + super(TestStructureField.StructureTest, self).__init__() + + def test_get_size(self): + field = self.StructureTest()['field'] + expected = 8 + actual = len(field) + assert actual == expected + + def test_to_string(self): + field = self.StructureTest()['field'] + expected = """Structure2: + field = 125 + bytes = 10 11 12 13 + + Raw Hex: + 7D 00 00 00 10 11 12 13""" + actual = str(field) + assert actual == expected + + def test_get_value(self): + field = self.StructureTest()['field'] + expected = Structure2() + actual = field.get_value() + assert actual.pack() == expected.pack() + + def test_pack(self): + field = self.StructureTest()['field'] + expected = b"\x7d\x00\x00\x00\x10\x11\x12\x13" + actual = field.pack() + assert actual == expected + + def test_pack_without_type(self): + field = self.StructureTest()['field'] + field.structure_type = None + + test_value = b"\x7d\x00\x00\x00\x14\x15\x16\x17" + field.set_value(test_value) + actual = field.pack() + assert actual == test_value + + def test_unpack(self): + field = self.StructureTest()['field'] + field.unpack(b"\x7d\x00\x00\x00\x14\x15\x16\x17") + expected = b"\x7d\x00\x00\x00\x14\x15\x16\x17" + actual = field.get_value() + assert actual.pack() == expected + assert isinstance(actual, Structure2) + + def test_unpack_without_type(self): + field = self.StructureTest()['field'] + field.structure_type = None + field.unpack(b"\x7d\x00\x00\x00\x14\x15\x16\x17") + expected = b"\x7d\x00\x00\x00\x14\x15\x16\x17" + actual = field.get_value() + assert actual == expected + + def test_set_none(self): + field = self.StructureTest()['field'] + field.set_value(None) + expected = b"" + actual = field.get_value() + assert isinstance(field.value, bytes) + assert actual == expected + + def test_set_empty_byte(self): + field = self.StructureTest()['field'] + field.set_value(b"") + expected = b"" + actual = field.get_value() + assert isinstance(field.value, bytes) + assert actual == expected + + def test_set_lambda(self): + structure = self.StructureTest() + field = structure['field'] + field.name = "field" + field.structure = self.StructureTest + field.set_value(lambda s: b"\x7d\x00\x00\x00\x14\x15\x16\x17") + expected = b"\x7d\x00\x00\x00\x14\x15\x16\x17" + actual = field.get_value() + assert isinstance(field.value, types.LambdaType) + assert actual.pack() == expected + assert isinstance(actual, Structure2) + assert len(field) == 8 + + def test_set_lambda_without_type(self): + structure = self.StructureTest() + field = structure['field'] + field.name = "field" + field.structure = self.StructureTest + field.structure_type = None + field.set_value(lambda s: b"\x7d\x00\x00\x00\x14\x15\x16\x17") + expected = b"\x7d\x00\x00\x00\x14\x15\x16\x17" + actual = field.get_value() + assert isinstance(field.value, types.LambdaType) + assert actual == expected + assert isinstance(actual, bytes) + assert len(field) == 8 + + def test_set_bytes(self): + field = self.StructureTest()['field'] + field.set_value(b"\x7d\x00\x00\x00\x14\x15\x16\x17") + expected = b"\x7d\x00\x00\x00\x14\x15\x16\x17" + actual = field.get_value() + assert isinstance(field.value, Structure2) + assert actual.pack() == expected + + def test_set_bytes_without_type(self): + field = self.StructureTest()['field'] + field.structure_type = None + field.set_value(b"\x7d\x00\x00\x00\x14\x15\x16\x17") + expected = b"\x7d\x00\x00\x00\x14\x15\x16\x17" + actual = field.get_value() + assert isinstance(field.value, bytes) + assert actual == expected + + def test_set_bytes_then_structure_type(self): + field = self.StructureTest()['field'] + field.structure_type = None + field.set_value(b"\x7d\x00\x00\x00\x14\x15\x16\x17") + expected = b"\x7d\x00\x00\x00\x14\x15\x16\x17" + actual = field.get_value() + assert isinstance(field.value, bytes) + assert actual == expected + field.set_structure_type(Structure2) + + actual = field.get_value() + assert isinstance(field.value, Structure2) + assert actual.pack() == expected + + def test_set_bytes_with_lambda_type(self): + field = self.StructureTest()['field'] + field.structure_type = lambda s: Structure2 + field.set_value(b"\x7d\x00\x00\x00\x14\x15\x16\x17") + expected = b"\x7d\x00\x00\x00\x14\x15\x16\x17" + actual = field.get_value() + assert isinstance(field.value, Structure2) + assert actual.pack() == expected + + def test_set_structure(self): + field = self.StructureTest()['field'] + expected = b"\x7d\x00\x00\x00\x10\x11\x12\x13" + actual = field.get_value() + assert isinstance(field.value, Structure) + assert actual.pack() == expected + assert len(field) == 8 + + def test_get_structure_field(self): + field = self.StructureTest()['field'] + expected = 125 + actual = field['field'].get_value() + assert actual == expected + + def test_fail_get_structure_field_missing(self): + field = self.StructureTest()['field'] + with pytest.raises(ValueError) as exc: + field['fake'] + assert str(exc.value) == "Structure does not contain field fake" + + def test_fail_get_structure_bytes_value(self): + field = self.StructureTest()['field'] + field.structure_type = None + field.set_value(b"\x7d\x00\x00\x00\x14\x15\x16\x17") + with pytest.raises(ValueError) as exc: + field['field'] + assert str(exc.value) == "Cannot get field field when structure is " \ + "defined as a byte string" + + def test_set_structure_field(self): + field = self.StructureTest()['field'] + test_value = 100 + field['field'] = test_value + actual = field['field'].get_value() + assert actual == test_value + # test out the normal path (convoluted way) + assert field.get_value()['field'].get_value() == test_value + + def test_set_invalid(self): + field = self.StructureTest()['field'] + field.name = "field" + with pytest.raises(TypeError) as exc: + field.set_value([]) + assert str(exc.value) == "Cannot parse value for field field of " \ + "type list to a structure" + + +class TestUuidField(object): + + class StructureTest(Structure): + def __init__(self): + self.fields = OrderedDict([ + ('field', UuidField()) + ]) + super(TestUuidField.StructureTest, self).__init__() + + def test_get_size(self): + field = self.StructureTest()['field'] + expected = 16 + actual = len(field) + assert actual == expected + + def test_to_string(self): + field = self.StructureTest()['field'] + expected = "00000000-0000-0000-0000-000000000000" + actual = str(field) + assert actual == expected + + def test_get_value(self): + field = self.StructureTest()['field'] + expected = uuid.UUID("00000000-0000-0000-0000-000000000000") + actual = field.get_value() + assert actual == expected + + def test_pack(self): + field = self.StructureTest()['field'] + expected = b"\x00" * 16 + actual = field.pack() + assert actual == expected + + def test_unpack(self): + field = self.StructureTest()['field'] + field.unpack(b"\x11" * 16) + expected = uuid.UUID(bytes=b"\x11" * 16) + actual = field.get_value() + assert actual == expected + + def test_set_none(self): + field = self.StructureTest()['field'] + field.set_value(None) + expected = uuid.UUID("00000000-0000-0000-0000-000000000000") + actual = field.get_value() + assert isinstance(field.value, uuid.UUID) + assert actual == expected + + def test_set_lambda(self): + structure = self.StructureTest() + field = structure['field'] + field.name = "field" + field.structure = self.StructureTest + field.set_value(lambda s: uuid.UUID(bytes=b"\x11" * 16)) + expected = uuid.UUID(bytes=b"\x11" * 16) + actual = field.get_value() + assert isinstance(field.value, types.LambdaType) + assert actual == expected + assert len(field) == 16 + + def test_set_bytes(self): + field = self.StructureTest()['field'] + field.set_value(b"\x22" * 16) + expected = uuid.UUID(bytes=b"\x22" * 16) + actual = field.get_value() + assert isinstance(field.value, uuid.UUID) + assert actual == expected + + def test_set_int(self): + field = self.StructureTest()['field'] + field.set_value(45370982256125128461783280990902428194) + expected = uuid.UUID(int=45370982256125128461783280990902428194) + actual = field.get_value() + assert isinstance(field.value, uuid.UUID) + assert actual == expected + + def test_set_invalid(self): + field = self.StructureTest()['field'] + field.name = "field" + with pytest.raises(TypeError) as exc: + field.set_value([]) + assert str(exc.value) == "Cannot parse value for field field of " \ + "type list to a uuid" + + def test_invalid_size_none(self): + with pytest.raises(InvalidFieldDefinition) as exc: + UuidField(size=8) + assert str(exc.value) == "UuidField type must have a size of 16 not 8" + + def test_pack_uuid_field_big_endian(self): + field = self.StructureTest()['field'] + field.little_endian = False + field.set_value(uuid.UUID("00000001-0001-0001-0001-000000000001")) + expected = b"\x01\x00\x00\x00\x01\x00\x01\x00" \ + b"\x00\x01\x00\x00\x00\x00\x00\x01" + actual = field.pack() + assert actual == expected + + def test_unpack_uuid_field_big_endian(self): + field = self.StructureTest()['field'] + field.little_endian = False + field.unpack(b"\x01\x00\x00\x00\x01\x00\x01\x00" + b"\x00\x01\x00\x00\x00\x00\x00\x01") + expected = uuid.UUID("00000001-0001-0001-0001-000000000001") + actual = field.get_value() + assert actual == expected + + +class TestDateTimeField(object): + + DATE = datetime(year=1993, month=6, day=11, hour=7, minute=52, + second=34, microsecond=34) + + class StructureTest(Structure): + def __init__(self): + self.fields = OrderedDict([ + ('field', DateTimeField( + default=TestDateTimeField.DATE, + )) + ]) + super(TestDateTimeField.StructureTest, self).__init__() + + def test_get_size(self): + field = self.StructureTest()['field'] + expected = 8 + actual = len(field) + assert actual == expected + + def test_to_string(self): + field = self.StructureTest()['field'] + expected = "1993-06-11 07:52:34.000034" + actual = str(field) + assert actual == expected + + def test_get_value(self): + field = self.StructureTest()['field'] + expected = self.DATE + actual = field.get_value() + assert actual == expected + + def test_pack(self): + field = self.StructureTest()['field'] + expected = b"\x54\x0e\x63\x5e\x2d\xfa\xb7\x01" + actual = field.pack() + assert actual == expected + + def test_unpack(self): + field = self.StructureTest()['field'] + field.unpack(b"\x5e\x70\x27\x4a\x6e\x23\x93\x01") + expected = datetime(year=1960, month=8, day=1, hour=22, minute=7, + second=1, microsecond=186774) + actual = field.get_value() + assert actual == expected + + def test_set_none(self): + field = self.StructureTest()['field'] + field.set_value(None) + expected = datetime.today() + actual = field.get_value() + assert isinstance(field.value, datetime) + assert actual.year == expected.year + assert actual.month == expected.month + assert actual.day == expected.day + + def test_set_lambda(self): + structure = self.StructureTest() + field = structure['field'] + field.name = "field" + field.structure = self.StructureTest + field.set_value(lambda s: datetime(year=1960, month=8, day=2, hour=8, + minute=7, second=1, + microsecond=186774)) + expected = datetime(year=1960, month=8, day=2, hour=8, minute=7, + second=1, microsecond=186774) + actual = field.get_value() + assert isinstance(field.value, types.LambdaType) + assert actual == expected + assert len(field) == 8 + + def test_set_bytes(self): + field = self.StructureTest()['field'] + field.set_value(b"\x00\x67\x7b\x21\x3d\x5d\xd3\x01") + expected = datetime(year=2017, month=11, day=14, hour=11, minute=38, + second=46) + actual = field.get_value() + assert isinstance(field.value, datetime) + assert actual == expected + + def test_set_int(self): + field = self.StructureTest()['field'] + field.set_value(131551331260000000) + expected = datetime(year=2017, month=11, day=14, hour=11, minute=38, + second=46) + actual = field.get_value() + assert isinstance(field.value, datetime) + assert actual == expected + + def test_set_datetime(self): + field = self.StructureTest()['field'] + datetime_value = datetime(year=2017, month=11, day=14, hour=21, + minute=38, second=46) + field.set_value(datetime_value) + actual = field.get_value() + assert isinstance(field.value, datetime) + assert actual == datetime_value + + def test_set_invalid(self): + field = self.StructureTest()['field'] + field.name = "field" + with pytest.raises(TypeError) as exc: + field.set_value([]) + assert str(exc.value) == "Cannot parse value for field field of " \ + "type list to a datetime" + + def test_invalid_size_none(self): + with pytest.raises(InvalidFieldDefinition) as exc: + DateTimeField(size=4) + assert str(exc.value) == "DateTimeField type must have a size of 8 " \ + "not 4" + + +class TestEnumField(object): + + class StructureTest(Structure): + def __init__(self): + self.fields = OrderedDict([ + ('field', EnumField( + size=1, + enum_type=Commands, + default=Commands.SMB2_IOCTL, + )), + ]) + super(TestEnumField.StructureTest, self).__init__() + + def test_get_size(self): + field = self.StructureTest()['field'] + expected = 1 + actual = len(field) + assert actual == expected + + def test_to_string(self): + field = self.StructureTest()['field'] + expected = "(11) SMB2_IOCTL" + actual = str(field) + assert actual == expected + + def test_to_string_default_as_zero(self): + class StructureTestDefaultZero(Structure): + def __init__(self): + self.fields = OrderedDict([ + ('field', EnumField( + size=2, + enum_type=Dialects, + )) + ]) + super(StructureTestDefaultZero, self).__init__() + field = StructureTestDefaultZero()['field'] + expected = "(0) UNKNOWN_ENUM" + actual = str(field) + assert actual == expected + + def test_get_value(self): + field = self.StructureTest()['field'] + expected = 11 + actual = field.get_value() + assert actual == expected + + def test_pack(self): + field = self.StructureTest()['field'] + expected = b"\x0b" + actual = field.pack() + assert actual == expected + + def test_unpack(self): + field = self.StructureTest()['field'] + field.unpack(b"\x0b") + expected = 11 + actual = field.get_value() + assert actual == expected + + def test_set_none(self): + field = self.StructureTest()['field'] + field.set_value(None) + expected = 0 + actual = field.get_value() + assert actual == expected + assert isinstance(field.value, int) + assert str(field) == "(0) SMB2_NEGOTIATE" + + def test_set_bytes(self): + field = self.StructureTest()['field'] + field.set_value(b"\x08") + expected = 8 + actual = field.get_value() + assert actual == expected + assert isinstance(field.value, int) + assert str(field) == "(8) SMB2_READ" + + def test_set_int(self): + field = self.StructureTest()['field'] + field.set_value(8) + expected = 8 + actual = field.get_value() + assert actual == expected + assert isinstance(field.value, int) + assert str(field) == "(8) SMB2_READ" + + def test_set_invalid(self): + field = self.StructureTest()['field'] + field.name = "field" + with pytest.raises(TypeError) as exc: + field.set_value([]) + assert str(exc.value) == "Cannot parse value for field field of " \ + "type list to an int" + + def test_set_invalid_value(self): + field = self.StructureTest()['field'] + with pytest.raises(ValueError) as exc: + field.set_value(0x13) + assert str(exc.value) == "Enum value 19 does not exist in enum type " \ + "" + + +class TestFlagField(object): + + class StructureTest(Structure): + def __init__(self): + self.fields = OrderedDict([ + ('field', FlagField( + size=4, + flag_type=Capabilities, + default=Capabilities.SMB2_GLOBAL_CAP_LEASING | + Capabilities.SMB2_GLOBAL_CAP_ENCRYPTION + )), + ]) + super(TestFlagField.StructureTest, self).__init__() + + def test_get_size(self): + field = self.StructureTest()['field'] + expected = 4 + actual = len(field) + assert actual == expected + + def test_to_string(self): + field = self.StructureTest()['field'] + expected = "(66) SMB2_GLOBAL_CAP_ENCRYPTION, SMB2_GLOBAL_CAP_LEASING" + actual = str(field) + assert actual == expected + + def test_get_value(self): + field = self.StructureTest()['field'] + expected = 66 + actual = field.get_value() + assert actual == expected + + def test_pack(self): + field = self.StructureTest()['field'] + expected = b"\x42\x00\x00\x00" + actual = field.pack() + assert actual == expected + + def test_unpack(self): + field = self.StructureTest()['field'] + field.unpack(b"\x4a\x00\x00\x00") + expected = 74 + actual = field.get_value() + assert actual == expected + + def test_set_none(self): + field = self.StructureTest()['field'] + field.set_value(None) + expected = 0 + actual = field.get_value() + assert actual == expected + assert isinstance(field.value, int) + assert str(field) == "0" + + def test_set_bytes(self): + field = self.StructureTest()['field'] + field.set_value(b"\x08\x00\x00\x00") + expected = 8 + actual = field.get_value() + assert actual == expected + assert isinstance(field.value, int) + assert str(field) == "(8) SMB2_GLOBAL_CAP_MULTI_CHANNEL" + + def test_set_int(self): + field = self.StructureTest()['field'] + field.set_value(8) + expected = 8 + actual = field.get_value() + assert actual == expected + assert isinstance(field.value, int) + assert str(field) == "(8) SMB2_GLOBAL_CAP_MULTI_CHANNEL" + + def test_set_invalid(self): + field = self.StructureTest()['field'] + field.name = "field" + with pytest.raises(TypeError) as exc: + field.set_value([]) + assert str(exc.value) == "Cannot parse value for field field of " \ + "type list to an int" + + def test_check_flag_set(self): + field = self.StructureTest()['field'] + assert field.has_flag(Capabilities.SMB2_GLOBAL_CAP_ENCRYPTION) + assert not field.has_flag(Capabilities.SMB2_GLOBAL_CAP_MULTI_CHANNEL) + + def test_set_flag(self): + field = self.StructureTest()['field'] + assert not field.has_flag(Capabilities.SMB2_GLOBAL_CAP_MULTI_CHANNEL) + field.set_flag(Capabilities.SMB2_GLOBAL_CAP_MULTI_CHANNEL) + assert field.has_flag(Capabilities.SMB2_GLOBAL_CAP_MULTI_CHANNEL) + + def test_set_invalid_flag(self): + field = self.StructureTest()['field'] + with pytest.raises(ValueError) as ex: + field.set_flag(10) + assert str(ex.value) == "Flag value does not exist in flag type " \ + "" + + def test_set_invalid_value(self): + field = self.StructureTest()['field'] + with pytest.raises(ValueError) as exc: + field.set_value(0x00000082) + assert str(exc.value) == "Invalid flag for field field value set 128" + + +class TestBoolField(object): + + class StructureTest(Structure): + def __init__(self): + self.fields = OrderedDict([ + ('field', BoolField(size=1)) + ]) + super(TestBoolField.StructureTest, self).__init__() + + def test_get_size(self): + field = self.StructureTest()['field'] + expected = 1 + actual = len(field) + assert actual == expected + + def test_to_string(self): + field = self.StructureTest()['field'] + expected = "False" + actual = str(field) + assert actual == expected + + def test_to_string_true(self): + field = self.StructureTest()['field'] + field.set_value(True) + expected = "True" + actual = str(field) + assert actual == expected + + def test_get_value(self): + field = self.StructureTest()['field'] + expected = False + actual = field.get_value() + assert actual == expected + + def test_get_value_true(self): + field = self.StructureTest()['field'] + field.set_value(True) + expected = True + actual = field.get_value() + assert actual == expected + + def test_pack(self): + field = self.StructureTest()['field'] + expected = b"\x00" + actual = field.pack() + assert actual == expected + + def test_pack_true(self): + field = self.StructureTest()['field'] + field.set_value(True) + expected = b"\x01" + actual = field.pack() + assert actual == expected + + def test_unpack(self): + field = self.StructureTest()['field'] + field.unpack(b"\x00") + expected = False + actual = field.get_value() + assert actual == expected + + def test_unpack_true(self): + field = self.StructureTest()['field'] + field.unpack(b"\x01") + expected = True + actual = field.get_value() + assert actual == expected + + def test_invalid_size_bad_int(self): + with pytest.raises(InvalidFieldDefinition) as exc: + BoolField(size=2) + assert str(exc.value) == "BoolField size must have a value of 1, not 2" + + def test_set_none(self): + field = self.StructureTest()['field'] + field.set_value(None) + expected = False + actual = field.get_value() + assert isinstance(field.value, bool) + assert actual == expected + + def test_set_bytes(self): + field = self.StructureTest()['field'] + field.set_value(b"\x01") + expected = True + actual = field.get_value() + assert isinstance(field.value, bool) + assert actual == expected + + def test_set_bool(self): + field = self.StructureTest()['field'] + field.set_value(True) + expected = True + actual = field.get_value() + assert isinstance(field.value, bool) + assert actual == expected + + def test_set_lambda(self): + structure = self.StructureTest() + field = structure['field'] + field.name = "field" + field.structure = self.StructureTest + field.set_value(lambda s: True) + expected = True + actual = field.get_value() + assert isinstance(field.value, types.LambdaType) + assert actual == expected + assert len(field) == 1 + + def test_set_invalid(self): + field = self.StructureTest()['field'] + field.name = "field" + with pytest.raises(TypeError) as exc: + field.set_value([]) + assert str(exc.value) == "Cannot parse value for field field of " \ + "type list to a bool" diff --git a/tests/test_transport.py b/tests/test_transport.py new file mode 100644 index 00000000..661b3211 --- /dev/null +++ b/tests/test_transport.py @@ -0,0 +1,41 @@ +import pytest + +from smbprotocol.transport import DirectTCPPacket, Tcp + + +class TestDirectTcpPacket(object): + + def test_create_message(self): + message = DirectTCPPacket() + message['smb2_message'] = b"\xfe\x53\x4d\x42" + expected = b"\x00\x00\x00\x04" \ + b"\xfe\x53\x4d\x42" + + actual = message.pack() + assert len(message) == 8 + assert message['stream_protocol_length'].get_value() == 4 + assert actual == expected + + def test_parse_message(self): + actual = DirectTCPPacket() + data = b"\x00\x00\x00\x04" \ + b"\xfe\x53\x4d\x42" + actual.unpack(data) + assert len(actual) == 8 + assert actual['stream_protocol_length'].get_value() == 4 + assert isinstance(actual['smb2_message'].get_value(), bytes) + + actual_header = actual['smb2_message'] + assert len(actual_header) == 4 + assert actual_header.get_value() == b"\xfe\x53\x4d\x42" + + +class TestTcp(object): + + def test_normal_fail_message_too_big(self): + tcp = Tcp("0.0.0.0", 0) + with pytest.raises(ValueError) as exc: + tcp.send(b"\x00" * 16777216) + assert str(exc.value) == "Data to be sent over Direct TCP size " \ + "16777216 exceeds the max length allowed " \ + "16777215" diff --git a/tests/test_tree.py b/tests/test_tree.py new file mode 100644 index 00000000..141c664e --- /dev/null +++ b/tests/test_tree.py @@ -0,0 +1,252 @@ +import uuid + +import pytest + +from smbprotocol.connection import Connection, Dialects +from smbprotocol.exceptions import SMBException, SMBResponseException +from smbprotocol.session import Session +from smbprotocol.tree import SMB2TreeConnectRequest, SMB2TreeConnectResponse, \ + SMB2TreeDisconnect, TreeConnect + +from .utils import smb_real + + +class TestSMB2TreeConnectRequest(object): + + def test_create_message(self): + message = SMB2TreeConnectRequest() + message['flags'] = 2 + message['buffer'] = "\\\\127.0.0.1\\c$".encode("utf-16-le") + expected = b"\x09\x00" \ + b"\x02\x00" \ + b"\x48\x00" \ + b"\x1c\x00" \ + b"\x5c\x00\x5c\x00\x31\x00\x32\x00" \ + b"\x37\x00\x2e\x00\x30\x00\x2e\x00" \ + b"\x30\x00\x2e\x00\x31\x00\x5c\x00" \ + b"\x63\x00\x24\x00" + actual = message.pack() + assert len(message) == 36 + assert actual == expected + + def test_parse_message(self): + actual = SMB2TreeConnectRequest() + data = b"\x09\x00" \ + b"\x02\x00" \ + b"\x48\x00" \ + b"\x1c\x00" \ + b"\x5c\x00\x5c\x00\x31\x00\x32\x00" \ + b"\x37\x00\x2e\x00\x30\x00\x2e\x00" \ + b"\x30\x00\x2e\x00\x31\x00\x5c\x00" \ + b"\x63\x00\x24\x00" + actual.unpack(data) + assert len(actual) == 36 + assert actual['structure_size'].get_value() == 9 + assert actual['flags'].get_value() == 2 + assert actual['path_offset'].get_value() == 72 + assert actual['path_length'].get_value() == 28 + assert actual['buffer'].get_value() == "\\\\127.0.0.1\\c$"\ + .encode("utf-16-le") + + +class TestSMB2TreeConnectResponse(object): + + def test_create_message(self): + message = SMB2TreeConnectResponse() + message['share_type'] = 1 + message['share_flags'] = 2 + message['capabilities'] = 8 + message['maximal_access'] = 10 + expected = b"\x10\x00" \ + b"\x01" \ + b"\x00" \ + b"\x02\x00\x00\x00" \ + b"\x08\x00\x00\x00" \ + b"\x0a\x00\x00\x00" + actual = message.pack() + assert len(message) == 16 + assert actual == expected + + def test_parse_message(self): + actual = SMB2TreeConnectResponse() + data = b"\x10\x00" \ + b"\x01" \ + b"\x00" \ + b"\x02\x00\x00\x00" \ + b"\x08\x00\x00\x00" \ + b"\x0a\x00\x00\x00" + actual.unpack(data) + assert len(actual) == 16 + assert actual['structure_size'].get_value() == 16 + assert actual['share_type'].get_value() == 1 + assert actual['reserved'].get_value() == 0 + assert actual['share_flags'].get_value() == 2 + assert actual['capabilities'].get_value() == 8 + assert actual['maximal_access'].get_value() == 10 + + +class TestSMB2TreeDisconnect(object): + + def test_create_message(self): + message = SMB2TreeDisconnect() + expected = b"\x04\x00" \ + b"\x00\x00" + actual = message.pack() + assert len(message) == 4 + assert actual == expected + + def test_parse_message(self): + actual = SMB2TreeDisconnect() + data = b"\x04\x00" \ + b"\x00\x00" + actual.unpack(data) + assert len(actual) == 4 + assert actual['structure_size'].get_value() == 4 + assert actual['reserved'].get_value() == 0 + + +class TestTreeConnect(object): + + def test_dialect_2_0_2(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3]) + connection.connect(Dialects.SMB_2_0_2) + session = Session(connection, smb_real[0], smb_real[1], False) + tree = TreeConnect(session, smb_real[4]) + try: + session.connect() + tree.connect() + assert tree.encrypt_data is None + assert not tree.is_ca_share + assert not tree.is_dfs_share + assert not tree.is_scaleout_share + assert isinstance(tree.tree_connect_id, int) + finally: + connection.disconnect(True) + + def test_dialect_2_1_0(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3]) + connection.connect(Dialects.SMB_2_1_0) + session = Session(connection, smb_real[0], smb_real[1], False) + tree = TreeConnect(session, smb_real[4]) + try: + session.connect() + tree.connect() + assert tree.encrypt_data is None + assert not tree.is_ca_share + assert not tree.is_dfs_share + assert not tree.is_scaleout_share + assert isinstance(tree.tree_connect_id, int) + finally: + connection.disconnect(True) + + def test_dialect_3_0_0(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3]) + connection.connect(Dialects.SMB_3_0_0) + session = Session(connection, smb_real[0], smb_real[1]) + tree = TreeConnect(session, smb_real[4]) + try: + session.connect() + tree.connect() + assert not tree.encrypt_data + assert not tree.is_ca_share + assert not tree.is_dfs_share + assert not tree.is_scaleout_share + assert isinstance(tree.tree_connect_id, int) + finally: + connection.disconnect(True) + + def test_dialect_3_0_2(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3]) + connection.connect(Dialects.SMB_3_0_2) + session = Session(connection, smb_real[0], smb_real[1]) + tree = TreeConnect(session, smb_real[4]) + try: + session.connect() + tree.connect() + assert not tree.encrypt_data + assert not tree.is_ca_share + assert not tree.is_dfs_share + assert not tree.is_scaleout_share + assert isinstance(tree.tree_connect_id, int) + finally: + connection.disconnect(True) + + def test_dialect_3_1_1(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3]) + connection.connect(Dialects.SMB_3_1_1) + session = Session(connection, smb_real[0], smb_real[1]) + tree = TreeConnect(session, smb_real[4]) + try: + session.connect() + tree.connect() + assert not tree.encrypt_data + assert not tree.is_ca_share + assert not tree.is_dfs_share + assert not tree.is_scaleout_share + assert isinstance(tree.tree_connect_id, int) + finally: + connection.disconnect(True) + + def test_dialect_2_encrypted_share(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3]) + connection.connect(Dialects.SMB_2_1_0) + session = Session(connection, smb_real[0], smb_real[1], False) + tree = TreeConnect(session, smb_real[5]) + try: + session.connect() + with pytest.raises(SMBResponseException) as exc: + tree.connect() + assert str(exc.value) == "Received unexpected status from the " \ + "server: (3221225506) " \ + "STATUS_ACCESS_DENIED: 0xc0000022" + finally: + connection.disconnect(True) + + def test_dialect_3_encrypted_share(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3]) + connection.connect(Dialects.SMB_3_1_1) + session = Session(connection, smb_real[0], smb_real[1]) + tree = TreeConnect(session, smb_real[5]) + try: + session.connect() + tree.connect() + assert tree.encrypt_data + assert not tree.is_ca_share + assert not tree.is_dfs_share + assert not tree.is_scaleout_share + assert isinstance(tree.tree_connect_id, int) + finally: + connection.disconnect(True) + + def test_secure_negotiation_verification_failed(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3]) + connection.connect(Dialects.SMB_3_0_2) + session = Session(connection, smb_real[0], smb_real[1]) + connection.dialect = Dialects.SMB_3_0_0 + tree = TreeConnect(session, smb_real[4]) + try: + session.connect() + with pytest.raises(SMBException) as exc: + tree.connect() + assert "Secure negotiate failed to verify server dialect, " \ + "Actual: 770, Expected: 768" in str(exc.value) + finally: + connection.disconnect(True) + + def test_secure_ignore_negotiation_verification_failed(self, smb_real): + connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3]) + connection.connect(Dialects.SMB_3_0_2) + session = Session(connection, smb_real[0], smb_real[1]) + connection.dialect = Dialects.SMB_3_0_0 + tree = TreeConnect(session, smb_real[4]) + try: + session.connect() + tree.connect(False) + assert not tree.encrypt_data + assert not tree.is_ca_share + assert not tree.is_dfs_share + assert not tree.is_scaleout_share + assert isinstance(tree.tree_connect_id, int) + finally: + connection.disconnect(True) + tree.disconnect() # test that disconnect can be run mutliple times diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 00000000..a34ea92e --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,25 @@ +import os + +import pytest + + +@pytest.fixture(scope='module') +def smb_real(): + # for these tests to work the server at SMB_SERVER must support dialect + # 3.1.1, without this some checks will fail as we test 3.1.1 specific + # features + username = os.environ.get('SMB_USER', None) + password = os.environ.get('SMB_PASSWORD', None) + server = os.environ.get('SMB_SERVER', None) + port = os.environ.get('SMB_PORT', None) + share = os.environ.get('SMB_SHARE', None) + skip = os.environ.get('SMB_SKIP', "False") == "True" + + if username and password and server and port and share and not skip: + share = r"\\%s\%s" % (server, share) + encrypted_share = "%s-encrypted" % share + return username, password, server, int(port), share, encrypted_share + else: + pytest.skip("SMB_USER, SMB_PASSWORD, SMB_PORT, SMB_SHARE, " + "environment variables were not set, integration tests " + "will be skipped") diff --git a/tox.ini b/tox.ini new file mode 100644 index 00000000..ddd4c198 --- /dev/null +++ b/tox.ini @@ -0,0 +1,7 @@ +[tox] +envlist = py26,py27,py34,py35,py36 + +[testenv] +deps= -rrequirements-test.txt +commands=py.test -v --pep8 --cov smbprotocol --cov-report term-missing +passenv=SMB_* \ No newline at end of file