From 6262ebfb45a8aaa8cb3b48479a3f83ca4d552f9a Mon Sep 17 00:00:00 2001 From: Pat Buxton Date: Mon, 17 Jun 2024 10:05:01 +0100 Subject: [PATCH] Try to check token for all requests --- plaid/auth_oidc.py | 2 +- plaid/security.py | 50 +++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 50 insertions(+), 2 deletions(-) diff --git a/plaid/auth_oidc.py b/plaid/auth_oidc.py index ee921d5a5db7e..dd6a9821b5232 100644 --- a/plaid/auth_oidc.py +++ b/plaid/auth_oidc.py @@ -69,5 +69,5 @@ class PlaidAuthOAuthView(AuthOAuthView): @expose("/login/") def login(self, provider=None): if provider is None: - return super().login(provider='plaid-keycloak') + return super().login(provider='plaidkeycloak') return super().login(provider=provider) diff --git a/plaid/security.py b/plaid/security.py index 3c140fce3fb5f..df5df77440931 100644 --- a/plaid/security.py +++ b/plaid/security.py @@ -15,6 +15,7 @@ from flask_appbuilder import Model from flask_appbuilder.security.manager import AUTH_OID, AUTH_OAUTH from authlib.integrations.flask_client import OAuth +from authlib.integrations.flask_client import token_update from requests.exceptions import HTTPError from plaidcloud.rpc.connection.jsonrpc import SimpleRPC @@ -64,12 +65,30 @@ def __init__(self, appbuilder): client_kwargs=self.oidc_params['client_kwargs'], ) self.authoidview = AuthOIDCView + if self.auth_type == AUTH_OAUTH: self.authoauthview = PlaidAuthOAuthView + @token_update.connect_via(appbuilder) + def on_token_update(sender, name, token, refresh_token=None, access_token=None): + # if refresh_token: + # item = OAuth2Token.find(name=name, refresh_token=refresh_token) + # elif access_token: + # item = OAuth2Token.find(name=name, access_token=access_token) + # else: + # return + # + # # update old token + # item.access_token = token['access_token'] + # item.refresh_token = token.get('refresh_token') + # item.expires_at = token['expires_at'] + # item.save() + log.info(f'Updated token for {name} - {repr(token)}') + self.appbuilder.sm.set_oauth_session(name, token) + def oauth_user_info(self, provider, response=None): # logging.debug("Oauth2 provider: {0}.".format(provider)) - if provider == 'plaid-keycloak': + if provider == 'plaidkeycloak': me = self.appbuilder.sm.oauth_remotes[provider].get("userinfo") me.raise_for_status() data = me.json() @@ -370,3 +389,32 @@ def add_user_to_project(self, user, project_id): log.debug( "Appended %s to %s roles list.", role.name, user.username ) + + def has_access(self, permission_name: str, view_name: str) -> bool: + # check token expiry and logout, then continue previous auth check + if self.auth_type == AUTH_OAUTH: + if 'oauth' in session: + token, secret = session['oauth'] + # provider = session["oauth_provider"] + if not token_is_valid(token): + logout_user() + session.clear() + + elif self.auth_type == AUTH_OID: + if 'token' in session: + token = session['token'] + if not token_is_valid(token): + logout_user() + session.clear() + + return super().has_access(permission_name, view_name) + + +def token_is_valid(access_token): + try: + decoded_token = jwt.decode(access_token, options={'verify_signature': False}) + expiration_timestamp = decoded_token['exp'] + current_timestamp = time.time() + return expiration_timestamp > current_timestamp + except (jwt.exceptions.DecodeError, KeyError): + return False