Skip to content

Commit

Permalink
Add refresh token support for Azure AD OAuth provider (Chainlit#1599)
Browse files Browse the repository at this point in the history
* feat(oauth): add refresh token support to Azure AD provider

- Store refresh token from OAuth response in provider instance
- Add refresh token to user metadata in get_user_info method
- Enable offline_access scope to receive refresh tokens

This enhancement allows applications to persist refresh tokens
for maintaining long-term access to Azure AD resources.

* Add 'offline_access' scope to enable refresh token support.

* add refresh token support to AzureADHybridOauthProvider class

* Apply Black formatting to oauth_providers.py

---------

Co-authored-by: Mathijs de Bruin <[email protected]>
  • Loading branch information
SigveSjovold and dokterbob authored Jan 27, 2025
1 parent a3ecfb2 commit f308392
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions backend/chainlit/oauth_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def __init__(self):
self.authorize_params = {
"tenant": os.environ.get("OAUTH_AZURE_AD_TENANT_ID"),
"response_type": "code",
"scope": "https://graph.microsoft.com/User.Read",
"scope": "https://graph.microsoft.com/User.Read offline_access",
"response_mode": "query",
}

Expand All @@ -207,10 +207,12 @@ async def get_token(self, code: str, url: str):
json = response.json()

token = json["access_token"]
refresh_token = json.get("refresh_token")
if not token:
raise HTTPException(
status_code=400, detail="Failed to get the access token"
)
self._refresh_token = refresh_token
return token

async def get_user_info(self, token: str):
Expand Down Expand Up @@ -239,7 +241,11 @@ async def get_user_info(self, token: str):

user = User(
identifier=azure_user["userPrincipalName"],
metadata={"image": azure_user.get("image"), "provider": "azure-ad"},
metadata={
"image": azure_user.get("image"),
"provider": "azure-ad",
"refresh_token": getattr(self, "_refresh_token", None),
},
)
return (azure_user, user)

Expand Down Expand Up @@ -269,7 +275,7 @@ def __init__(self):
self.authorize_params = {
"tenant": os.environ.get("OAUTH_AZURE_AD_HYBRID_TENANT_ID"),
"response_type": "code id_token",
"scope": "https://graph.microsoft.com/User.Read https://graph.microsoft.com/openid",
"scope": "https://graph.microsoft.com/User.Read https://graph.microsoft.com/openid offline_access",
"response_mode": "form_post",
"nonce": nonce,
}
Expand All @@ -294,10 +300,12 @@ async def get_token(self, code: str, url: str):
json = response.json()

token = json["access_token"]
refresh_token = json.get("refresh_token")
if not token:
raise HTTPException(
status_code=400, detail="Failed to get the access token"
)
self._refresh_token = refresh_token
return token

async def get_user_info(self, token: str):
Expand Down Expand Up @@ -326,7 +334,11 @@ async def get_user_info(self, token: str):

user = User(
identifier=azure_user["userPrincipalName"],
metadata={"image": azure_user.get("image"), "provider": "azure-ad"},
metadata={
"image": azure_user.get("image"),
"provider": "azure-ad",
"refresh_token": getattr(self, "_refresh_token", None),
},
)
return (azure_user, user)

Expand Down

0 comments on commit f308392

Please sign in to comment.