diff --git a/.rubocop.yml b/.rubocop.yml index 42188a85..987342cf 100644 --- a/.rubocop.yml +++ b/.rubocop.yml @@ -29,7 +29,7 @@ Metrics/AbcSize: Max: 25 Metrics/ClassLength: - Max: 112 + Max: 121 Metrics/ModuleLength: Max: 100 diff --git a/CHANGELOG.md b/CHANGELOG.md index 40c3b871..6bd19790 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ **Features:** +- Support custom algorithms by passing algorithm objects[#512](https://github.com/jwt/ruby-jwt/pull/512) ([@anakinj](https://github.com/anakinj)). - Your contribution here **Fixes and enhancements:** diff --git a/README.md b/README.md index 07aa62bc..a985f19d 100644 --- a/README.md +++ b/README.md @@ -211,6 +211,33 @@ decoded_token = JWT.decode token, rsa_public, true, { algorithm: 'PS256' } puts decoded_token ``` +### **Custom algorithms** + +An object implementing custom signing or verification behaviour can be passed in the `algorithm` option when encoding and decoding. The given object needs to implement the method `valid_alg?` and `verify` and/or `alg` and `sign`, depending if object is used for encoding or decoding. + +```ruby + module CustomHS512Algorithm + def self.alg + 'HS512' + end + + def self.valid_alg?(alg_to_validate) + alg_to_validate == alg + end + + def self.sign(data:, signing_key:) + OpenSSL::HMAC.digest(OpenSSL::Digest.new('sha512'), data, signing_key) + end + + def self.verify(data:, signature:, verification_key:) + ::OpenSSL.secure_compare(sign(data: data, signing_key: verification_key), signature) + end + end + + token = ::JWT.encode({'pay' => 'load'}, 'secret', CustomHS512Algorithm) + payload, header = ::JWT.decode(token, 'secret', true, algorithm: CustomHS512Algorithm) +``` + ## Support for reserved claim names JSON Web Token defines some reserved claim names and defines how they should be used. JWT supports these reserved claim names: diff --git a/lib/jwt/algos.rb b/lib/jwt/algos.rb index aa46a60f..3420ae65 100644 --- a/lib/jwt/algos.rb +++ b/lib/jwt/algos.rb @@ -1,5 +1,13 @@ # frozen_string_literal: true +begin + require 'rbnacl' +rescue LoadError + raise if defined?(RbNaCl) +end +require 'openssl' + +require 'jwt/security_utils' require 'jwt/algos/hmac' require 'jwt/algos/eddsa' require 'jwt/algos/ecdsa' @@ -7,10 +15,9 @@ require 'jwt/algos/ps' require 'jwt/algos/none' require 'jwt/algos/unsupported' +require 'jwt/algos/algo_wrapper' -# JWT::Signature module module JWT - # Signature logic for JWT module Algos extend self @@ -28,14 +35,23 @@ def find(algorithm) indexed[algorithm && algorithm.downcase] end + def create(algorithm) + Algos::AlgoWrapper.new(*find(algorithm)) + end + + def implementation?(algorithm) + (algorithm.respond_to?(:valid_alg?) && algorithm.respond_to?(:verify)) || + (algorithm.respond_to?(:alg) && algorithm.respond_to?(:sign)) + end + private def indexed @indexed ||= begin - fallback = [Algos::Unsupported, nil] - ALGOS.each_with_object(Hash.new(fallback)) do |alg, hash| - alg.const_get(:SUPPORTED).each do |code| - hash[code.downcase] = [alg, code] + fallback = [nil, Algos::Unsupported] + ALGOS.each_with_object(Hash.new(fallback)) do |cls, hash| + cls.const_get(:SUPPORTED).each do |alg| + hash[alg.downcase] = [alg, cls] end end end diff --git a/lib/jwt/algos/algo_wrapper.rb b/lib/jwt/algos/algo_wrapper.rb new file mode 100644 index 00000000..caf823e6 --- /dev/null +++ b/lib/jwt/algos/algo_wrapper.rb @@ -0,0 +1,30 @@ +# frozen_string_literal: true + +module JWT + module Algos + class AlgoWrapper + attr_reader :alg, :cls + + def initialize(alg, cls) + @alg = alg + @cls = cls + end + + def valid_alg?(alg_to_check) + alg.casecmp(alg_to_check)&.zero? == true + end + + def sign(data:, signing_key:) + cls.sign(alg, data, signing_key) + end + + def verify(data:, signature:, verification_key:) + cls.verify(alg, verification_key, data, signature) + rescue OpenSSL::PKey::PKeyError # These should be moved to the algorithms that actually need this, but left here to ensure nothing will break. + raise JWT::VerificationError, 'Signature verification raised' + ensure + OpenSSL.errors.clear + end + end + end +end diff --git a/lib/jwt/algos/ecdsa.rb b/lib/jwt/algos/ecdsa.rb index 09a149e4..ea154bd3 100644 --- a/lib/jwt/algos/ecdsa.rb +++ b/lib/jwt/algos/ecdsa.rb @@ -30,8 +30,7 @@ module Ecdsa SUPPORTED = NAMED_CURVES.map { |_, c| c[:algorithm] }.uniq.freeze - def sign(to_sign) - algorithm, msg, key = to_sign.values + def sign(algorithm, msg, key) curve_definition = curve_by_name(key.group.curve_name) key_algorithm = curve_definition[:algorithm] if algorithm != key_algorithm @@ -42,8 +41,7 @@ def sign(to_sign) SecurityUtils.asn1_to_raw(key.dsa_sign_asn1(digest.digest(msg)), key) end - def verify(to_verify) - algorithm, public_key, signing_input, signature = to_verify.values + def verify(algorithm, public_key, signing_input, signature) curve_definition = curve_by_name(public_key.group.curve_name) key_algorithm = curve_definition[:algorithm] if algorithm != key_algorithm diff --git a/lib/jwt/algos/eddsa.rb b/lib/jwt/algos/eddsa.rb index d9f626e3..cf21ad94 100644 --- a/lib/jwt/algos/eddsa.rb +++ b/lib/jwt/algos/eddsa.rb @@ -7,8 +7,7 @@ module Eddsa SUPPORTED = %w[ED25519 EdDSA].freeze - def sign(to_sign) - algorithm, msg, key = to_sign.values + def sign(algorithm, msg, key) if key.class != RbNaCl::Signatures::Ed25519::SigningKey raise EncodeError, "Key given is a #{key.class} but has to be an RbNaCl::Signatures::Ed25519::SigningKey" end @@ -19,8 +18,7 @@ def sign(to_sign) key.sign(msg) end - def verify(to_verify) - algorithm, public_key, signing_input, signature = to_verify.values + def verify(algorithm, public_key, signing_input, signature) unless SUPPORTED.map(&:downcase).map(&:to_sym).include?(algorithm.downcase.to_sym) raise IncorrectAlgorithm, "payload algorithm is #{algorithm} but #{key.primitive} signing key was provided" end diff --git a/lib/jwt/algos/hmac.rb b/lib/jwt/algos/hmac.rb index b298f518..adb3ec62 100644 --- a/lib/jwt/algos/hmac.rb +++ b/lib/jwt/algos/hmac.rb @@ -7,8 +7,7 @@ module Hmac SUPPORTED = %w[HS256 HS512256 HS384 HS512].freeze - def sign(to_sign) - algorithm, msg, key = to_sign.values + def sign(algorithm, msg, key) key ||= '' authenticator, padded_key = SecurityUtils.rbnacl_fixup(algorithm, key) if authenticator && padded_key @@ -18,8 +17,7 @@ def sign(to_sign) end end - def verify(to_verify) - algorithm, public_key, signing_input, signature = to_verify.values + def verify(algorithm, public_key, signing_input, signature) authenticator, padded_key = SecurityUtils.rbnacl_fixup(algorithm, public_key) if authenticator && padded_key begin @@ -28,7 +26,7 @@ def verify(to_verify) false end else - SecurityUtils.secure_compare(signature, sign(JWT::Signature::ToSign.new(algorithm, signing_input, public_key))) + SecurityUtils.secure_compare(signature, sign(algorithm, signing_input, public_key)) end end end diff --git a/lib/jwt/algos/none.rb b/lib/jwt/algos/none.rb index 7de3733a..84ea99da 100644 --- a/lib/jwt/algos/none.rb +++ b/lib/jwt/algos/none.rb @@ -7,7 +7,9 @@ module None SUPPORTED = %w[none].freeze - def sign(*); end + def sign(*) + '' + end def verify(*) true diff --git a/lib/jwt/algos/ps.rb b/lib/jwt/algos/ps.rb index 20d182c0..a30c3268 100644 --- a/lib/jwt/algos/ps.rb +++ b/lib/jwt/algos/ps.rb @@ -9,11 +9,9 @@ module Ps SUPPORTED = %w[PS256 PS384 PS512].freeze - def sign(to_sign) + def sign(algorithm, msg, key) require_openssl! - algorithm, msg, key = to_sign.values - key_class = key.class raise EncodeError, "The given key is a #{key_class}. It has to be an OpenSSL::PKey::RSA instance." if key_class == String @@ -23,10 +21,10 @@ def sign(to_sign) key.sign_pss(translated_algorithm, msg, salt_length: :digest, mgf1_hash: translated_algorithm) end - def verify(to_verify) + def verify(algorithm, public_key, signing_input, signature) require_openssl! - SecurityUtils.verify_ps(to_verify.algorithm, to_verify.public_key, to_verify.signing_input, to_verify.signature) + SecurityUtils.verify_ps(algorithm, public_key, signing_input, signature) end def require_openssl! diff --git a/lib/jwt/algos/rsa.rb b/lib/jwt/algos/rsa.rb index 252c8f88..e7e54daa 100644 --- a/lib/jwt/algos/rsa.rb +++ b/lib/jwt/algos/rsa.rb @@ -7,15 +7,14 @@ module Rsa SUPPORTED = %w[RS256 RS384 RS512].freeze - def sign(to_sign) - algorithm, msg, key = to_sign.values + def sign(algorithm, msg, key) raise EncodeError, "The given key is a #{key.class}. It has to be an OpenSSL::PKey::RSA instance." if key.instance_of?(String) key.sign(OpenSSL::Digest.new(algorithm.sub('RS', 'sha')), msg) end - def verify(to_verify) - SecurityUtils.verify_rsa(to_verify.algorithm, to_verify.public_key, to_verify.signing_input, to_verify.signature) + def verify(algorithm, public_key, signing_input, signature) + SecurityUtils.verify_rsa(algorithm, public_key, signing_input, signature) end end end diff --git a/lib/jwt/decode.rb b/lib/jwt/decode.rb index 2852f7f5..9445c23f 100644 --- a/lib/jwt/decode.rb +++ b/lib/jwt/decode.rb @@ -2,9 +2,9 @@ require 'json' -require 'jwt/signature' require 'jwt/verify' require 'jwt/x5c_key_finder' + # JWT::Decode module module JWT # Decoding logic for JWT @@ -24,7 +24,7 @@ def initialize(jwt, key, verify, options, &keyfinder) def decode_segments validate_segment_count! if @verify - decode_crypto + decode_signature verify_algo set_key verify_signature @@ -51,8 +51,8 @@ def verify_signature def verify_algo raise(JWT::IncorrectAlgorithm, 'An algorithm must be specified') if allowed_algorithms.empty? - raise(JWT::IncorrectAlgorithm, 'Token is missing alg header') unless algorithm - raise(JWT::IncorrectAlgorithm, 'Expected a different algorithm') unless options_includes_algo_in_header? + raise(JWT::IncorrectAlgorithm, 'Token is missing alg header') unless alg_in_header + raise(JWT::IncorrectAlgorithm, 'Expected a different algorithm') unless valid_alg_in_header? end def set_key @@ -64,27 +64,50 @@ def set_key end def verify_signature_for?(key) - Signature.verify(algorithm, key, signing_input, @signature) + allowed_algorithms.any? do |alg| + alg.verify(data: signing_input, signature: @signature, verification_key: key) + end + end + + def valid_alg_in_header? + allowed_algorithms.any? { |alg| alg.valid_alg?(alg_in_header) } end - def options_includes_algo_in_header? - allowed_algorithms.any? { |alg| alg.casecmp(algorithm).zero? } + # Order is very important - first check for string keys, next for symbols + ALGORITHM_KEYS = ['algorithm', + :algorithm, + 'algorithms', + :algorithms].freeze + + def given_algorithms + ALGORITHM_KEYS.each do |alg_key| + alg = @options[alg_key] + return Array(alg) if alg + end + [] end def allowed_algorithms - # Order is very important - first check for string keys, next for symbols - algos = if @options.key?('algorithm') - @options['algorithm'] - elsif @options.key?(:algorithm) - @options[:algorithm] - elsif @options.key?('algorithms') - @options['algorithms'] - elsif @options.key?(:algorithms) - @options[:algorithms] - else - [] + @allowed_algorithms ||= resolve_allowed_algorithms + end + + def resolve_allowed_algorithms + algs = given_algorithms.map do |alg| + if Algos.implementation?(alg) + alg + else + Algos.create(alg) + end end - Array(algos) + + sort_by_alg_header(algs) + end + + # Move algorithms matching the JWT alg header to the beginning of the list + def sort_by_alg_header(algs) + return algs if algs.size <= 1 + + algs.partition { |alg| alg.valid_alg?(alg_in_header) }.flatten end def find_key(&keyfinder) @@ -113,14 +136,14 @@ def segment_length end def none_algorithm? - algorithm == 'none' + alg_in_header == 'none' end - def decode_crypto + def decode_signature @signature = ::JWT::Base64.url_decode(@segments[2] || '') end - def algorithm + def alg_in_header header['alg'] end diff --git a/lib/jwt/encode.rb b/lib/jwt/encode.rb index c3a3f80c..252ddf9b 100644 --- a/lib/jwt/encode.rb +++ b/lib/jwt/encode.rb @@ -1,28 +1,35 @@ # frozen_string_literal: true -require_relative './algos' -require_relative './claims_validator' +require_relative 'algos' +require_relative 'claims_validator' # JWT::Encode module module JWT # Encoding logic for JWT class Encode - ALG_NONE = 'none' - ALG_KEY = 'alg' + ALG_KEY = 'alg' def initialize(options) - @payload = options[:payload] - @key = options[:key] - _, @algorithm = Algos.find(options[:algorithm]) - @headers = options[:headers].transform_keys(&:to_s) + @payload = options[:payload] + @key = options[:key] + @algorithm = resolve_algorithm(options[:algorithm]) + @headers = options[:headers].transform_keys(&:to_s) + @headers[ALG_KEY] = @algorithm.alg end def segments - @segments ||= combine(encoded_header_and_payload, encoded_signature) + validate_claims! + combine(encoded_header_and_payload, encoded_signature) end private + def resolve_algorithm(algorithm) + return algorithm if Algos.implementation?(algorithm) + + Algos.create(algorithm) + end + def encoded_header @encoded_header ||= encode_header end @@ -40,25 +47,28 @@ def encoded_header_and_payload end def encode_header - @headers[ALG_KEY] = @algorithm - encode(@headers) + encode_data(@headers) end def encode_payload - if @payload.is_a?(Hash) - ClaimsValidator.new(@payload).validate! - end + encode_data(@payload) + end - encode(@payload) + def signature + @algorithm.sign(data: encoded_header_and_payload, signing_key: @key) end - def encode_signature - return '' if @algorithm == ALG_NONE + def validate_claims! + return unless @payload.is_a?(Hash) + + ClaimsValidator.new(@payload).validate! + end - ::JWT::Base64.url_encode(JWT::Signature.sign(@algorithm, encoded_header_and_payload, @key)) + def encode_signature + ::JWT::Base64.url_encode(signature) end - def encode(data) + def encode_data(data) ::JWT::Base64.url_encode(JWT::JSON.generate(data)) end diff --git a/lib/jwt/signature.rb b/lib/jwt/signature.rb deleted file mode 100644 index e46a5f8f..00000000 --- a/lib/jwt/signature.rb +++ /dev/null @@ -1,35 +0,0 @@ -# frozen_string_literal: true - -require 'jwt/security_utils' -require 'openssl' -require 'jwt/algos' -begin - require 'rbnacl' -rescue LoadError - raise if defined?(RbNaCl) -end - -# JWT::Signature module -module JWT - # Signature logic for JWT - module Signature - module_function - - ToSign = Struct.new(:algorithm, :msg, :key) - ToVerify = Struct.new(:algorithm, :public_key, :signing_input, :signature) - - def sign(algorithm, msg, key) - algo, code = Algos.find(algorithm) - algo.sign ToSign.new(code, msg, key) - end - - def verify(algorithm, key, signing_input, signature) - algo, code = Algos.find(algorithm) - algo.verify(ToVerify.new(code, key, signing_input, signature)) - rescue OpenSSL::PKey::PKeyError - raise JWT::VerificationError, 'Signature verification raised' - ensure - OpenSSL.errors.clear - end - end -end diff --git a/spec/integration/readme_examples_spec.rb b/spec/integration/readme_examples_spec.rb index 75c65e64..009aadef 100644 --- a/spec/integration/readme_examples_spec.rb +++ b/spec/integration/readme_examples_spec.rb @@ -352,4 +352,29 @@ expect(jwk_hash[:kid].size).to eq(43) end end + + context 'custom algorithm example' do + it 'allows a module to be used as algorithm on encode and decode' do + custom_hs512_alg = Module.new do + def self.alg + 'HS512' + end + + def self.valid_alg?(alg_to_validate) + alg_to_validate == alg + end + + def self.sign(data:, signing_key:) + OpenSSL::HMAC.digest(OpenSSL::Digest.new('sha512'), data, signing_key) + end + + def self.verify(data:, signature:, verification_key:) + sign(data: data, signing_key: verification_key) == signature + end + end + + token = ::JWT.encode({ 'pay' => 'load' }, 'secret', custom_hs512_alg) + _payload, _header = ::JWT.decode(token, 'secret', true, algorithm: custom_hs512_alg) + end + end end diff --git a/spec/jwt_spec.rb b/spec/jwt_spec.rb index 0f4d4b88..7f672574 100644 --- a/spec/jwt_spec.rb +++ b/spec/jwt_spec.rb @@ -729,7 +729,7 @@ let(:token) { JWT.encode(payload, 'HS256', 'HS256') } it 'decodes the token but does not pass the payload' do expect(JWT.decode(token, nil, true, algorithm: 'HS256') do |header, token_payload, nothing| - expect(token_payload).to eq(nil) # This behaviour is not correct, the payload should be available in the keyfinder + expect(token_payload).to eq(nil) # This behaviour is not correct, the payload should be available in the keyfinder expect(nothing).to eq(nil) header['alg'] end).to include(payload) @@ -771,4 +771,101 @@ expect(JWT.decode(token, 'secret', true, algorithm: 'HS256')).to include(payload) end end + + context 'when multiple algorithms given' do + let(:token) { JWT.encode(payload, 'secret', 'HS256') } + + it 'starts trying with the algorithm referred in the header' do + expect(::JWT::Algos::Rsa).not_to receive(:verify) + JWT.decode(token, 'secret', true, algorithm: ['RS512', 'HS256']) + end + end + + context 'when algorithm is a custom class' do + let(:custom_algorithm) do + Class.new do + attr_reader :alg + + def initialize(signature: 'custom_signature', alg: 'custom') + @signature = signature + @alg = alg + end + + def sign(*) + @signature + end + + def verify(data:, signature:, verification_key:) # rubocop:disable Lint/UnusedMethodArgument + signature == @signature + end + + def valid_alg?(alg) + alg == self.alg + end + end + end + + let(:token) { JWT.encode(payload, 'secret', custom_algorithm.new) } + let(:expected_token) { 'eyJhbGciOiJjdXN0b20ifQ.eyJ1c2VyX2lkIjoic29tZUB1c2VyLnRsZCJ9.Y3VzdG9tX3NpZ25hdHVyZQ' } + + it 'can be used for encoding' do + expect(token).to eq(expected_token) + end + + it 'can be used for decoding' do + expect(JWT.decode(token, 'secret', true, algorithm: custom_algorithm.new)).to eq([payload, { 'alg' => 'custom' }]) + end + + context 'when multiple custom algorithms are given for decoding' do + it 'tries until the first match' do + expect(JWT.decode(token, 'secret', true, algorithms: [custom_algorithm.new(signature: 'not_this'), custom_algorithm.new])).to eq([payload, { 'alg' => 'custom' }]) + end + end + + context 'when alg is not matching' do + it 'fails the validation process' do + expect { JWT.decode(token, 'secret', true, algorithms: custom_algorithm.new(alg: 'not_a_match')) }.to raise_error(JWT::IncorrectAlgorithm, 'Expected a different algorithm') + end + end + + context 'when signature is not matching' do + it 'fails the validation process' do + expect { JWT.decode(token, 'secret', true, algorithms: custom_algorithm.new(signature: 'not_a_match')) }.to raise_error(JWT::VerificationError, 'Signature verification failed') + end + end + + context 'when #sign method is missing' do + before do + custom_algorithm.instance_eval do + remove_method :sign + end + end + + # This behaviour should be somehow nicer + it 'raises an error on encoding' do + expect { token }.to raise_error(NoMethodError) + end + + it 'allows decoding' do + expect(JWT.decode(expected_token, 'secret', true, algorithm: custom_algorithm.new)).to eq([payload, { 'alg' => 'custom' }]) + end + end + + context 'when #verify method is missing' do + before do + custom_algorithm.instance_eval do + remove_method :verify + end + end + + it 'can be used for encoding' do + expect(token).to eq(expected_token) + end + + # This behaviour should be somehow nicer + it 'raises error on decoding' do + expect { JWT.decode(expected_token, 'secret', true, algorithm: custom_algorithm.new) }.to raise_error(NoMethodError) + end + end + end end