diff --git a/backend/chainlit/oauth_providers.py b/backend/chainlit/oauth_providers.py index 9481425a01..8054b2716f 100644 --- a/backend/chainlit/oauth_providers.py +++ b/backend/chainlit/oauth_providers.py @@ -183,7 +183,7 @@ def __init__(self): self.authorize_params = { "tenant": os.environ.get("OAUTH_AZURE_AD_TENANT_ID"), "response_type": "code", - "scope": "https://graph.microsoft.com/User.Read", + "scope": "https://graph.microsoft.com/User.Read offline_access", "response_mode": "query", } @@ -207,10 +207,12 @@ async def get_token(self, code: str, url: str): json = response.json() token = json["access_token"] + refresh_token = json.get("refresh_token") if not token: raise HTTPException( status_code=400, detail="Failed to get the access token" ) + self._refresh_token = refresh_token return token async def get_user_info(self, token: str): @@ -239,7 +241,11 @@ async def get_user_info(self, token: str): user = User( identifier=azure_user["userPrincipalName"], - metadata={"image": azure_user.get("image"), "provider": "azure-ad"}, + metadata={ + "image": azure_user.get("image"), + "provider": "azure-ad", + "refresh_token": getattr(self, "_refresh_token", None), + }, ) return (azure_user, user) @@ -269,7 +275,7 @@ def __init__(self): self.authorize_params = { "tenant": os.environ.get("OAUTH_AZURE_AD_HYBRID_TENANT_ID"), "response_type": "code id_token", - "scope": "https://graph.microsoft.com/User.Read https://graph.microsoft.com/openid", + "scope": "https://graph.microsoft.com/User.Read https://graph.microsoft.com/openid offline_access", "response_mode": "form_post", "nonce": nonce, } @@ -294,10 +300,12 @@ async def get_token(self, code: str, url: str): json = response.json() token = json["access_token"] + refresh_token = json.get("refresh_token") if not token: raise HTTPException( status_code=400, detail="Failed to get the access token" ) + self._refresh_token = refresh_token return token async def get_user_info(self, token: str): @@ -326,7 +334,11 @@ async def get_user_info(self, token: str): user = User( identifier=azure_user["userPrincipalName"], - metadata={"image": azure_user.get("image"), "provider": "azure-ad"}, + metadata={ + "image": azure_user.get("image"), + "provider": "azure-ad", + "refresh_token": getattr(self, "_refresh_token", None), + }, ) return (azure_user, user)