Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Azure Auth Support #209

Merged
merged 3 commits into from
Dec 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions Pipfile
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,24 @@ verify_ssl = true
name = "pypi"

[packages]
jsonschema = "==4.19.2"
jsonschema = "==4.20.0"
jstyleson = "==0.0.2"
requests = "==2.31.0"
textx = "==3.1.1"
textx = "==4.0.1"
js2py = "==0.74"
requests-pkcs12 = "==1.22"
parsys-requests-unixsocket = "==0.3.1"
requests-aws4auth = "==1.2.3"
requests-ntlm = "==1.2.0"
restrictedpython = "==6.2"
faker = "==20.0.0"
restrictedpython = "==7.0"
faker = "==20.1.0"
requests-hawk = "==1.2.1"
pyyaml = "==6.0.1"
toml = "==0.10.2"
python-magic = "*"
msal = "==1.26.0"

[dev-packages]
python-magic = "*"
waitress = "==2.1.1"
flask = "*"
pyperf = "==2.6.2"
Expand Down
342 changes: 180 additions & 162 deletions Pipfile.lock

Large diffs are not rendered by default.

47 changes: 44 additions & 3 deletions dothttp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,17 @@

from .dsl_jsonparser import json_or_array_to_json
from .exceptions import *
from .parse_models import MultidefHttp, AuthWrap, DigestAuth, BasicAuth, Line, NtlmAuthWrap, Query, Http, NameWrap, UrlWrap, Header, \
from .parse_models import AzureAuthCli, AzureAuthType, AzureAuthWrap, MultidefHttp, AuthWrap, DigestAuth, BasicAuth, Line, NtlmAuthWrap, Query, Http, NameWrap, UrlWrap, Header, \
MultiPartFile, FilesWrap, TripleOrDouble, Payload as ParsePayload, Certificate, P12Certificate, ExtraArg, \
AWS_REGION_LIST, AWS_SERVICES_LIST, AwsAuthWrap, TestScript, ScriptType, HawkAuth
AWS_REGION_LIST, AWS_SERVICES_LIST, AwsAuthWrap, TestScript, ScriptType, HawkAuth, AzureAuthCertificate, \
AzureAuthDeviceCode, AzureAuthServicePrincipal
from .property_schema import property_schema
from .property_util import PropertyProvider
try:
from .azure_auth import AzureAuth
except:
# this is for dothttp-wasm, where msal most likely not installed
AzureAuth = None

try:
import magic
Expand Down Expand Up @@ -320,6 +326,9 @@ def get_http_from_req(self):
aws_auth.signing_key.secret_key,
aws_auth.service,
aws_auth.region))
elif isinstance(self.auth, AzureAuth):
auth_wrap = AuthWrap(
azure_auth=self.auth.azure_auth_wrap)
certificate = None
if self.certificate:
certificate = Certificate(*self.certificate)
Expand Down Expand Up @@ -1051,12 +1060,44 @@ def load_auth(self):
aws_service
)
else:
# region and aws_service can be extracted from url
# aws service and region can be extracted from url
# somehow library is not supporting those
# with current state, we are not support this use case
# we may come back
# all four parameters are required and are to be non empty
raise DothttpAwsAuthException(access_id=access_id)
elif azure_auth := auth_wrap.azure_auth:
azure_auth: AzureAuthWrap = azure_auth
if sp_auth := azure_auth.azure_spsecret_auth:
azure_auth_wrap = AzureAuthWrap(azure_spsecret_auth=AzureAuthServicePrincipal(
tenant_id=self.get_updated_content(sp_auth.tenant_id),
client_id=self.get_updated_content(sp_auth.client_id),
client_secret=self.get_updated_content(sp_auth.client_secret),
scope = self.get_updated_content(sp_auth.scope or "https://management.azure.com/.default")
), azure_auth_type=AzureAuthType.SERVICE_PRINCIPAL)
elif cert_auth := azure_auth.azure_spcert_auth:
azure_auth_wrap = AzureAuthWrap(azure_spcert_auth=AzureAuthCertificate(
tenant_id=self.get_updated_content(cert_auth.tenant_id),
client_id=self.get_updated_content(cert_auth.client_id),
certificate_path=self.get_updated_content(cert_auth.certificate_path),
scope=self.get_updated_content(cert_auth.scope or "https://management.azure.com/.default")
), azure_auth_type=AzureAuthType.CERTIFICATE)
elif azure_auth.azure_cli_auth:
azure_auth_wrap = AzureAuthWrap(
azure_cli_auth=AzureAuthCli(
scope=self.get_updated_content(
azure_auth.azure_cli_auth.scope or "https://management.azure.com/.default"
)
), azure_auth_type=AzureAuthType.CLI)
elif azure_auth.auth.azure_device_code:
azure_auth_wrap = AzureAuthWrap(
azure_device_code=AzureAuthDeviceCode(
scope=self.get_updated_content(
azure_auth.azure_device_code.scope or "https://management.azure.com/.default"
)
), azure_auth_type=AzureAuthType.DEVICE_CODE)

self.httpdef.auth = AzureAuth(azure_auth_wrap)

def get_current_or_base(self, attr_key) -> Any:
if getattr(self.http, attr_key):
Expand Down
2 changes: 1 addition & 1 deletion dothttp/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.0.43a1'
__version__ = '0.0.43a2'
158 changes: 158 additions & 0 deletions dothttp/azure_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
import json
import msal
import os
import pickle
import subprocess
import time
import logging

from requests.auth import AuthBase
from requests.models import PreparedRequest

from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.serialization import pkcs12

from .exceptions import DothttpAzureAuthException
from .parse_models import AzureAuthWrap, AzureAuthType, AzureAuthSP, AzureAuthCertificate

AZURE_CLI_TOKEN_STORE_PATH = os.path.expanduser('~/.dothttp.azure-cli.pkl')

AZURE_SP_TOKEN_STORE_PATH = os.path.expanduser(
'~/.dothttp.msal_token_cache.pkl')

request_logger = logging.getLogger("request")


def load_private_key_and_thumbprint(cert_path, password=None):
extension = os.path.splitext(cert_path)[1].lower()
with open(cert_path, "rb") as cert_file:
cert_data = cert_file.read()

if extension == '.pem':
private_key = serialization.load_pem_private_key(
cert_data, password, default_backend())
cert = x509.load_pem_x509_certificate(cert_data, default_backend())
elif extension == '.cer':
cert = x509.load_der_x509_certificate(cert_data, default_backend())
private_key = None # .cer files do not contain private key
elif extension == '.pfx' or extension == '.p12':
private_key, cert, _ = pkcs12.load_key_and_certificates(
cert_data, password, default_backend())
else:
raise ValueError(f"Unsupported certificate format {extension}")

if private_key is not None:
private_key_bytes = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption()
)
else:
private_key_bytes = None

thumbprint = cert.fingerprint(hashes.SHA1()).hex()

return private_key_bytes, thumbprint

class AzureAuth(AuthBase):

def __init__(self, azure_auth_wrap: AzureAuthWrap):
self.azure_auth_wrap = azure_auth_wrap
self.token_cache = msal.SerializableTokenCache()
try:
# Try to load the token cache from a file
with open(AZURE_SP_TOKEN_STORE_PATH, 'rb') as token_cache_file:
self.token_cache.deserialize(
json.dumps(pickle.load(token_cache_file)))
except FileNotFoundError:
# If the file does not exist, initialize a new token cache
pass

def __call__(self, r: PreparedRequest) -> PreparedRequest:
if self.azure_auth_wrap.azure_auth_type == AzureAuthType.SERVICE_PRINCIPAL:
self.acquire_token_silently_or_ondemand(
r, self.azure_auth_wrap.azure_spsecret_auth)
self.save_token_cache()
elif self.azure_auth_wrap.azure_auth_type == AzureAuthType.CERTIFICATE:
self.acquire_token_silently_or_ondemand(
r, self.azure_auth_wrap.azure_spcert_auth)
self.save_token_cache()
# For device code and cli authentication, we use the access token directly
# in future we can use msal to get the access token for device code
elif self.azure_auth_wrap.azure_auth_type in [AzureAuthType.CLI, AzureAuthType.DEVICE_CODE]:
access_token = None
expires_on = None
# Try to load the access token and its expiry time from a file
scope = self.azure_auth_wrap.azure_cli_auth.scope if self.azure_auth_wrap.azure_cli_auth else self.azure_auth_wrap.azure_device_code.scope

if os.path.exists(AZURE_CLI_TOKEN_STORE_PATH):
request_logger.debug("azure cli token already exists, using")
with open(AZURE_CLI_TOKEN_STORE_PATH, 'rb') as token_file:
data = pickle.load(token_file)
scope_wise_store = data.get(scope, {})
access_token = scope_wise_store.get('access_token', None)
expires_on = scope_wise_store.get('expires_on', None)
# Get the current time in seconds since the Epoch
current_time = time.time()


# If the file does not exist or the token has expired, get a new access token
if not access_token or not expires_on or current_time >= expires_on:
request_logger.debug(
"azure cli token store cached not availabile or expired")
# get token from cli by invoking az account get-access-token
result = subprocess.run(
["az", "account", "get-access-token", "--scope", scope], capture_output=True, text=True)
result_json = json.loads(result.stdout)
access_token = result_json['accessToken']
# Convert the expiresOn field to seconds since the Epoch
expires_on = time.mktime(time.strptime(result_json['expiresOn'], '%Y-%m-%d %H:%M:%S.%f'))
# Save the new access token and its expiry time to the file
with open(AZURE_CLI_TOKEN_STORE_PATH, 'wb') as token_file:
scope_wise_store = dict()
scope_wise_store[scope] = {
'access_token': access_token, 'expires_on': expires_on}
pickle.dump(scope_wise_store, token_file)
request_logger.debug(
"computed or fetched azure cli token access bearer token and appeneded")
r.headers["Authorization"] = f"Bearer {access_token}"
return r

def acquire_token_silently_or_ondemand(self, r, auth_wrap: AzureAuthSP):
kwargs = {
"client_id": auth_wrap.client_id,
"authority": f"https://login.microsoftonline.com/{auth_wrap.tenant_id}",
"token_cache": self.token_cache
}
if isinstance(auth_wrap, AzureAuthCertificate):
try:
private_key_bytes, thumbprint = load_private_key_and_thumbprint(
auth_wrap.certificate_path, auth_wrap.certificate_password)
kwargs["client_credential"] = {
"private_key": private_key_bytes,
"thumbprint": thumbprint
}
except Exception as e:
request_logger.error(
"loading private key failed with error", e)
raise DothttpAzureAuthException(message=str(e))
else:
kwargs["client_credential"] = auth_wrap.client_secret
app = self.create_confidential_app(kwargs)
accounts = app.get_accounts()
if accounts:
result = app.acquire_token_silent(scopes=[auth_wrap.scope], account=accounts[0])
if not accounts or "access_token" not in result:
result = app.acquire_token_for_client(scopes=[auth_wrap.scope])
r.headers["Authorization"] = f"Bearer {result['access_token']}"

def create_confidential_app(self, kwargs):
return msal.ConfidentialClientApplication(**kwargs)

def save_token_cache(self):
with open(AZURE_SP_TOKEN_STORE_PATH, 'wb') as token_cache_file:
pickle.dump(json.loads(self.token_cache.serialize()),
token_cache_file)
6 changes: 6 additions & 0 deletions dothttp/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,9 @@ class ScriptException(DotHttpException):
"AWSAuth expects all(access_id, secret_token, region, service) to be non empty access_id:`{access_id}`")
class DothttpAwsAuthException(DotHttpException):
pass


@exception_wrapper(
"AzureAuth exception: {message}")
class DothttpAzureAuthException(DotHttpException):
pass
38 changes: 37 additions & 1 deletion dothttp/http.tx
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,12 @@ HEADER:
;

AUTHWRAP:
digest_auth = DIGESTAUTH | basic_auth = BASICAUTH | ntlm_auth = NTLMAUTH | hawk_auth=HAWKAUTH | aws_auth = AWSAUTH
digest_auth = DIGESTAUTH
| basic_auth = BASICAUTH
| ntlm_auth = NTLMAUTH
| hawk_auth=HAWKAUTH
| aws_auth = AWSAUTH
| azure_auth = AZUREAUTH
;

CERTAUTH:
Expand Down Expand Up @@ -80,6 +85,37 @@ AWSAUTH:
| 'awsauth' '(' ('access_id' '=' ) ? access_id=DotString ',' ('secret_key' '=')? secret_token=DotString ',' ('service' '=')? service=DotString (',' ('region' '=')? region=DotString)? ')'
;

AZUREAUTH:
azure_spcert_auth = AZURECERTIFICATEAUTH
// for service principal
| azure_spsecret_auth = AZURESERVICEPRINCIPALAUTH
// if you don't have both, use device_code flow
// in this case, we use azure cli to get auth_token
| azure_device_auth = AZUREDEVICEAUTH
| azure_cli_auth = AZURECLIAUTH
;

AZURESERVICEPRINCIPALAUTH:

'azurespsecret' '(' ('tenant_id' '=')? tenant_id=DotString ',' ('client_id' '=')? client_id=DotString ',' ('client_secret' '=')? client_secret=DotString (',' ('scope' '=')? scope=DotString )? ')'

;

AZURECERTIFICATEAUTH:
// for certificate based
'azurespcert' '(' ('tenant_id' '=')? tenant_id=DotString ',' ('client_id' '=')? client_id=DotString ',' ('certificate_path' '=')? certificate_path=DotString (',' ('scope' '=')? scope=DotString )?')'

;

AZUREDEVICEAUTH:
'azuredevice' '(' (('scope' '=')? scope=DotString )? ')'
;

AZURECLIAUTH:
'azurecli' '(' (('scope' '=')? scope=DotString )? ')'
;


EXTRA_ARG:
// there can be more
clear=CLEAR_SESSION | insecure=INSECURE
Expand Down
Loading
Loading