Skip to content

Commit

Permalink
Fixes jpadilla#964: Validate key against allowed types for Algorithm …
Browse files Browse the repository at this point in the history
…family
  • Loading branch information
pachewise committed Sep 12, 2024
1 parent 467c748 commit 1dc9bb6
Showing 1 changed file with 55 additions and 30 deletions.
85 changes: 55 additions & 30 deletions jwt/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,25 +65,29 @@
load_ssh_public_key,
)

# pyjwt-964: we use these both for type checking below, as well as for validating the key passed in.
ALLOWED_RSA_KEY_TYPES = (RSAPrivateKey, RSAPublicKey)
ALLOWED_EC_KEY_TYPES = (EllipticCurvePrivateKey, EllipticCurvePublicKey)
ALLOWED_OKP_KEY_TYPES = (Ed25519PrivateKey, Ed25519PublicKey, Ed448PrivateKey, Ed448PublicKey)
ALLOWED_KEY_TYPES = ALLOWED_RSA_KEY_TYPES + ALLOWED_EC_KEY_TYPES + ALLOWED_OKP_KEY_TYPES
ALLOWED_PRIVATE_KEY_TYPES = (RSAPrivateKey, EllipticCurvePrivateKey, Ed25519PrivateKey, Ed448PrivateKey)
ALLOWED_PUBLIC_KEY_TYPES = (RSAPublicKey, EllipticCurvePublicKey, Ed25519PublicKey, Ed448PublicKey)

has_crypto = True
except ModuleNotFoundError:
has_crypto = False


if TYPE_CHECKING:
# TODO: should we move this to the top-level?
from typing import Union
# Type aliases for convenience in algorithms method signatures
AllowedRSAKeys = RSAPrivateKey | RSAPublicKey
AllowedECKeys = EllipticCurvePrivateKey | EllipticCurvePublicKey
AllowedOKPKeys = (
Ed25519PrivateKey | Ed25519PublicKey | Ed448PrivateKey | Ed448PublicKey
)
AllowedKeys = AllowedRSAKeys | AllowedECKeys | AllowedOKPKeys
AllowedPrivateKeys = (
RSAPrivateKey | EllipticCurvePrivateKey | Ed25519PrivateKey | Ed448PrivateKey
)
AllowedPublicKeys = (
RSAPublicKey | EllipticCurvePublicKey | Ed25519PublicKey | Ed448PublicKey
)
AllowedRSAKeys = Union[ALLOWED_RSA_KEY_TYPES]
AllowedECKeys = Union[ALLOWED_EC_KEY_TYPES]
AllowedOKPKeys = Union[ALLOWED_OKP_KEY_TYPES]
AllowedKeys = Union[ALLOWED_KEY_TYPES]
AllowedPrivateKeys = Union[ALLOWED_PRIVATE_KEY_TYPES]
AllowedPublicKeys = Union[ALLOWED_PUBLIC_KEY_TYPES]


requires_cryptography = {
Expand Down Expand Up @@ -141,6 +145,9 @@ class Algorithm(ABC):
The interface for an algorithm used to sign and verify tokens.
"""

# pyjwt-964: Validate to ensure the key passed in was decoded to the correct cryptography key family
_crypto_key_types: tuple[AllowedKeys, ...] = None

def compute_hash_digest(self, bytestr: bytes) -> bytes:
"""
Compute a hash digest using the specified algorithm's hash algorithm.
Expand All @@ -163,6 +170,26 @@ def compute_hash_digest(self, bytestr: bytes) -> bytes:
else:
return bytes(hash_alg(bytestr).digest())

def check_crypto_key_type(self, key: Any):
"""Check that the key belongs to the right cryptographic family.
Note that this method only works when cryptography is installed.
Args:
key (Any): Potentially a cryptography key
Raises:
InvalidKeyError: if the key doesn't match the expected key classes
"""
if not has_crypto or self._crypto_key_types is None:
return

# TODO check for algo_type? (e.g., SHA256 vs SHA384)
if not isinstance(key, self._crypto_key_types):
valid_classes = (cls.__name__ for cls in self._crypto_key_types)
actual_class = key.__class__.__name__
self_class = self.__class__.__name__
raise InvalidKeyError(f"Expected one of {valid_classes}, got: {actual_class}. Invalid Key type for {self_class}")

@abstractmethod
def prepare_key(self, key: Any) -> Any:
"""
Expand Down Expand Up @@ -323,11 +350,13 @@ class RSAAlgorithm(Algorithm):
SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512

_key_types = ALLOWED_RSA_KEY_TYPES

def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
self.hash_alg = hash_alg

def prepare_key(self, key: AllowedRSAKeys | str | bytes) -> AllowedRSAKeys:
if isinstance(key, (RSAPrivateKey, RSAPublicKey)):
if isinstance(key, self._crypto_key_types):
return key

if not isinstance(key, (bytes, str)):
Expand All @@ -337,17 +366,20 @@ def prepare_key(self, key: AllowedRSAKeys | str | bytes) -> AllowedRSAKeys:

try:
if key_bytes.startswith(b"ssh-rsa"):
return cast(RSAPublicKey, load_ssh_public_key(key_bytes))
loaded_key = cast(RSAPublicKey, load_ssh_public_key(key_bytes))
else:
return cast(
loaded_key = cast(
RSAPrivateKey, load_pem_private_key(key_bytes, password=None)
)
except ValueError:
try:
return cast(RSAPublicKey, load_pem_public_key(key_bytes))
loaded_key = cast(RSAPublicKey, load_pem_public_key(key_bytes))
except (ValueError, UnsupportedAlgorithm):
raise InvalidKeyError("Could not parse the provided public key.")

self.check_crypto_key_type(loaded_key)
return loaded_key

@overload
@staticmethod
def to_jwk(
Expand Down Expand Up @@ -491,11 +523,13 @@ class ECAlgorithm(Algorithm):
SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512

_crypto_key_types = ALLOWED_EC_KEY_TYPES

def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
self.hash_alg = hash_alg

def prepare_key(self, key: AllowedECKeys | str | bytes) -> AllowedECKeys:
if isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)):
if isinstance(key, self._crypto_key_types):
return key

if not isinstance(key, (bytes, str)):
Expand All @@ -515,12 +549,7 @@ def prepare_key(self, key: AllowedECKeys | str | bytes) -> AllowedECKeys:
crypto_key = load_pem_private_key(key_bytes, password=None) # type: ignore[assignment]

# Explicit check the key to prevent confusing errors from cryptography
if not isinstance(
crypto_key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)
):
raise InvalidKeyError(
"Expecting a EllipticCurvePrivateKey/EllipticCurvePublicKey. Wrong key provided for ECDSA algorithms"
)
self.check_crypto_key_type(crypto_key)

return crypto_key

Expand Down Expand Up @@ -700,6 +729,8 @@ class OKPAlgorithm(Algorithm):
This class requires ``cryptography>=2.6`` to be installed.
"""

_crypto_key_types = ALLOWED_OKP_KEY_TYPES

def __init__(self, **kwargs: Any) -> None:
pass

Expand All @@ -716,13 +747,7 @@ def prepare_key(self, key: AllowedOKPKeys | str | bytes) -> AllowedOKPKeys:
key = load_ssh_public_key(key_bytes) # type: ignore[assignment]

# Explicit check the key to prevent confusing errors from cryptography
if not isinstance(
key,
(Ed25519PrivateKey, Ed25519PublicKey, Ed448PrivateKey, Ed448PublicKey),
):
raise InvalidKeyError(
"Expecting a EllipticCurvePrivateKey/EllipticCurvePublicKey. Wrong key provided for EdDSA algorithms"
)
self.check_crypto_key_type(key)

return key

Expand Down

0 comments on commit 1dc9bb6

Please sign in to comment.