diff --git a/lib/omniauth-oauth2.rb b/lib/omniauth-oauth2.rb index 9986eca..a625178 100644 --- a/lib/omniauth-oauth2.rb +++ b/lib/omniauth-oauth2.rb @@ -1,2 +1,3 @@ require "omniauth-oauth2/version" +require "omniauth/strategies/oauth2/state_container" require "omniauth/strategies/oauth2" diff --git a/lib/omniauth/strategies/oauth2.rb b/lib/omniauth/strategies/oauth2.rb index e445214..b2a3530 100644 --- a/lib/omniauth/strategies/oauth2.rb +++ b/lib/omniauth/strategies/oauth2.rb @@ -40,6 +40,7 @@ def self.inherited(subclass) }, :code_challenge_method => "S256", } + option :state_container, StateContainer.new attr_accessor :access_token @@ -60,7 +61,7 @@ def request_phase end def authorize_params # rubocop:disable Metrics/AbcSize, Metrics/MethodLength - options.authorize_params[:state] = SecureRandom.hex(24) + options.authorize_params[:state] = new_state if OmniAuth.config.test_mode @env ||= {} @@ -72,7 +73,7 @@ def authorize_params # rubocop:disable Metrics/AbcSize, Metrics/MethodLength .merge(pkce_authorize_params) session["omniauth.pkce.verifier"] = options.pkce_verifier if options.pkce - session["omniauth.state"] = params[:state] + options.state_container.store(self, params[:state]) params end @@ -83,7 +84,7 @@ def token_params def callback_phase # rubocop:disable Metrics/AbcSize, Metrics/CyclomaticComplexity, Metrics/MethodLength, Metrics/PerceivedComplexity error = request.params["error_reason"] || request.params["error"] - if !options.provider_ignores_state && (request.params["state"].to_s.empty? || request.params["state"] != session.delete("omniauth.state")) + if !options.provider_ignores_state && (request.params["state"].to_s.empty? || request.params["state"] != options.state_container.take(self)) fail!(:csrf_detected, CallbackError.new(:csrf_detected, "CSRF detected")) elsif error fail!(error, CallbackError.new(request.params["error"], request.params["error_description"] || request.params["error_reason"], request.params["error_uri"])) @@ -100,6 +101,10 @@ def callback_phase # rubocop:disable Metrics/AbcSize, Metrics/CyclomaticComplexi fail!(:failed_to_connect, e) end + def new_state + SecureRandom.hex(24) + end + protected def pkce_authorize_params diff --git a/lib/omniauth/strategies/oauth2/state_container.rb b/lib/omniauth/strategies/oauth2/state_container.rb new file mode 100644 index 0000000..ad4f0b1 --- /dev/null +++ b/lib/omniauth/strategies/oauth2/state_container.rb @@ -0,0 +1,15 @@ +module OmniAuth + module Strategies + class OAuth2 + class StateContainer + def store(oauth2, state) + oauth2.session["omniauth.state"] = state + end + + def take(oauth2) + oauth2.session.delete("omniauth.state") + end + end + end + end +end diff --git a/spec/omniauth/strategies/oauth2/state_container_spec.rb b/spec/omniauth/strategies/oauth2/state_container_spec.rb new file mode 100644 index 0000000..92c9d12 --- /dev/null +++ b/spec/omniauth/strategies/oauth2/state_container_spec.rb @@ -0,0 +1,29 @@ +require "helper" + +describe OmniAuth::Strategies::OAuth2::StateContainer do + let(:state) { "random_state" } + let(:oauth2) { double("OAuth2", session: {}) } + + describe "#save_state" do + it "saves the state in the session" do + subject.store(oauth2, state) + + expect(oauth2.session["omniauth.state"]).to eq(state) + end + end + + describe "#take_state" do + before do + subject.store(oauth2, state) + end + + it "removes the state from the session" do + expect(oauth2.session).to include("omniauth.state") + + taken_state = subject.take(oauth2) + + expect(oauth2.session).not_to include("omniauth.state") + expect(taken_state).to eq(state) + end + end +end