From f308392fb22aa4001d16e5f807bfc60db6a9b4fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigve=20Sj=C3=B8vold?= <137484285+SigveSjovold@users.noreply.github.com> Date: Mon, 27 Jan 2025 17:16:45 +0100 Subject: [PATCH] Add refresh token support for Azure AD OAuth provider (#1599) * feat(oauth): add refresh token support to Azure AD provider - Store refresh token from OAuth response in provider instance - Add refresh token to user metadata in get_user_info method - Enable offline_access scope to receive refresh tokens This enhancement allows applications to persist refresh tokens for maintaining long-term access to Azure AD resources. * Add 'offline_access' scope to enable refresh token support. * add refresh token support to AzureADHybridOauthProvider class * Apply Black formatting to oauth_providers.py --------- Co-authored-by: Mathijs de Bruin --- backend/chainlit/oauth_providers.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/backend/chainlit/oauth_providers.py b/backend/chainlit/oauth_providers.py index 9481425a01..8054b2716f 100644 --- a/backend/chainlit/oauth_providers.py +++ b/backend/chainlit/oauth_providers.py @@ -183,7 +183,7 @@ def __init__(self): self.authorize_params = { "tenant": os.environ.get("OAUTH_AZURE_AD_TENANT_ID"), "response_type": "code", - "scope": "https://graph.microsoft.com/User.Read", + "scope": "https://graph.microsoft.com/User.Read offline_access", "response_mode": "query", } @@ -207,10 +207,12 @@ async def get_token(self, code: str, url: str): json = response.json() token = json["access_token"] + refresh_token = json.get("refresh_token") if not token: raise HTTPException( status_code=400, detail="Failed to get the access token" ) + self._refresh_token = refresh_token return token async def get_user_info(self, token: str): @@ -239,7 +241,11 @@ async def get_user_info(self, token: str): user = User( identifier=azure_user["userPrincipalName"], - metadata={"image": azure_user.get("image"), "provider": "azure-ad"}, + metadata={ + "image": azure_user.get("image"), + "provider": "azure-ad", + "refresh_token": getattr(self, "_refresh_token", None), + }, ) return (azure_user, user) @@ -269,7 +275,7 @@ def __init__(self): self.authorize_params = { "tenant": os.environ.get("OAUTH_AZURE_AD_HYBRID_TENANT_ID"), "response_type": "code id_token", - "scope": "https://graph.microsoft.com/User.Read https://graph.microsoft.com/openid", + "scope": "https://graph.microsoft.com/User.Read https://graph.microsoft.com/openid offline_access", "response_mode": "form_post", "nonce": nonce, } @@ -294,10 +300,12 @@ async def get_token(self, code: str, url: str): json = response.json() token = json["access_token"] + refresh_token = json.get("refresh_token") if not token: raise HTTPException( status_code=400, detail="Failed to get the access token" ) + self._refresh_token = refresh_token return token async def get_user_info(self, token: str): @@ -326,7 +334,11 @@ async def get_user_info(self, token: str): user = User( identifier=azure_user["userPrincipalName"], - metadata={"image": azure_user.get("image"), "provider": "azure-ad"}, + metadata={ + "image": azure_user.get("image"), + "provider": "azure-ad", + "refresh_token": getattr(self, "_refresh_token", None), + }, ) return (azure_user, user)