Skip to content

Commit

Permalink
ref: match signatures of refresh_identity and get_refresh_token_params
Browse files Browse the repository at this point in the history
  • Loading branch information
asottile-sentry committed Feb 21, 2025
1 parent d8d141a commit 19672f8
Show file tree
Hide file tree
Showing 9 changed files with 46 additions and 28 deletions.
3 changes: 2 additions & 1 deletion src/sentry/auth/providers/dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from sentry.auth.provider import MigratingIdentityId, Provider
from sentry.auth.providers.saml2.provider import Attributes, SAML2Provider
from sentry.auth.view import AuthView
from sentry.models.authidentity import AuthIdentity

PLACEHOLDER_TEMPLATE = '<form method="POST"><input type="email" name="email" /></form>'

Expand Down Expand Up @@ -37,7 +38,7 @@ def build_identity(self, state):
"name": "Dummy",
}

def refresh_identity(self, auth_identity):
def refresh_identity(self, auth_identity: AuthIdentity) -> None:
pass

def build_config(self, state):
Expand Down
3 changes: 2 additions & 1 deletion src/sentry/auth/providers/github/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from sentry.auth.exceptions import IdentityNotValid
from sentry.auth.providers.oauth2 import OAuth2Callback, OAuth2Login, OAuth2Provider
from sentry.auth.services.auth.model import RpcAuthProvider
from sentry.models.authidentity import AuthIdentity
from sentry.organizations.services.organization.model import RpcOrganization
from sentry.plugins.base.response import DeferredResponse

Expand Down Expand Up @@ -76,7 +77,7 @@ def build_identity(self, state):
"data": self.get_oauth_data(data),
}

def refresh_identity(self, auth_identity):
def refresh_identity(self, auth_identity: AuthIdentity) -> None:
with GitHubClient(auth_identity.data["access_token"]) as client:
try:
if not client.is_org_member(self.org["id"]):
Expand Down
5 changes: 3 additions & 2 deletions src/sentry/auth/providers/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from sentry.auth.provider import Provider
from sentry.auth.view import AuthView
from sentry.http import safe_urlopen, safe_urlread
from sentry.models.authidentity import AuthIdentity
from sentry.utils.http import absolute_uri

ERR_INVALID_STATE = "An error occurred while validating your request."
Expand Down Expand Up @@ -149,7 +150,7 @@ def get_auth_pipeline(self):
def get_refresh_token_url(self) -> str:
raise NotImplementedError

def get_refresh_token_params(self, refresh_token):
def get_refresh_token_params(self, refresh_token: str) -> dict[str, str]:
return {
"client_id": self.get_client_id(),
"client_secret": self.get_client_secret(),
Expand Down Expand Up @@ -186,7 +187,7 @@ def update_identity(self, new_data, current_data):
new_data.setdefault("refresh_token", current_data["refresh_token"])
return new_data

def refresh_identity(self, auth_identity):
def refresh_identity(self, auth_identity: AuthIdentity) -> None:
refresh_token = auth_identity.data.get("refresh_token")

if not refresh_token:
Expand Down
3 changes: 2 additions & 1 deletion src/sentry/auth/providers/saml2/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from sentry.auth.exceptions import IdentityNotValid
from sentry.auth.provider import Provider
from sentry.auth.view import AuthView
from sentry.models.authidentity import AuthIdentity
from sentry.models.authprovider import AuthProvider
from sentry.models.organization import OrganizationStatus
from sentry.models.organizationmapping import OrganizationMapping
Expand Down Expand Up @@ -308,7 +309,7 @@ def build_identity(self, state):
"name": name,
}

def refresh_identity(self, auth_identity):
def refresh_identity(self, auth_identity: AuthIdentity) -> None:
# Nothing to refresh
return

Expand Down
4 changes: 3 additions & 1 deletion src/sentry/identity/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import abc
import logging
from typing import Any

from sentry.pipeline import PipelineProvider
from sentry.users.models.identity import Identity


class Provider(PipelineProvider, abc.ABC):
Expand Down Expand Up @@ -50,7 +52,7 @@ def update_identity(self, new_data, current_data):
"""
return new_data

def refresh_identity(self, auth_identity, *args, **kwargs):
def refresh_identity(self, identity: Identity, **kwargs: Any) -> Identity:
"""
Updates the AuthIdentity with any changes from upstream. The primary
example of a change would be signalling this identity is no longer
Expand Down
11 changes: 8 additions & 3 deletions src/sentry/identity/gitlab/provider.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from __future__ import annotations

import logging
from typing import Any

import orjson

Expand All @@ -7,7 +10,7 @@
from sentry.http import safe_urlopen, safe_urlread
from sentry.identity.oauth2 import OAuth2Provider
from sentry.identity.services.identity import identity_service
from sentry.identity.services.identity.model import RpcIdentity
from sentry.users.models.identity import Identity
from sentry.utils.http import absolute_uri

logger = logging.getLogger("sentry.integration.gitlab")
Expand Down Expand Up @@ -75,7 +78,9 @@ def build_identity(self, data):
"data": self.get_oauth_data(data),
}

def get_refresh_token_params(self, refresh_token: str, identity: RpcIdentity):
def get_refresh_token_params(
self, refresh_token: str, identity: Identity, **kwargs: Any
) -> dict[str, str]:
client_id = identity.data.get("client_id")
client_secret = identity.data.get("client_secret")

Expand All @@ -87,7 +92,7 @@ def get_refresh_token_params(self, refresh_token: str, identity: RpcIdentity):
"client_secret": client_secret,
}

def refresh_identity(self, identity: RpcIdentity, *args, **kwargs):
def refresh_identity(self, identity: Identity, **kwargs: Any) -> Identity:
refresh_token = identity.data.get("refresh_token")
refresh_token_url = kwargs.get("refresh_token_url")

Expand Down
19 changes: 8 additions & 11 deletions src/sentry/identity/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
import secrets
from time import time
from typing import Any
from urllib.parse import parse_qsl, urlencode

import orjson
Expand All @@ -21,6 +22,7 @@
)
from sentry.pipeline import PipelineView
from sentry.shared_integrations.exceptions import ApiError
from sentry.users.models.identity import Identity
from sentry.utils.http import absolute_uri

from .base import Provider
Expand Down Expand Up @@ -112,13 +114,10 @@ def get_pipeline_views(self) -> list[PipelineView]:
),
]

def get_refresh_token_params(self, refresh_token, *args, **kwargs):
return {
"client_id": self.get_client_id(),
"client_secret": self.get_client_secret(),
"grant_type": "refresh_token",
"refresh_token": refresh_token,
}
def get_refresh_token_params(
self, refresh_token: str, identity: Identity, **kwargs: Any
) -> dict[str, str]:
raise NotImplementedError

def get_oauth_data(self, payload):
data = {"access_token": payload["access_token"]}
Expand Down Expand Up @@ -185,15 +184,13 @@ def handle_refresh_error(self, req, payload):
)
raise ApiError(formatted_error)

def refresh_identity(self, identity, *args, **kwargs):
def refresh_identity(self, identity: Identity, **kwargs: Any) -> Identity:
refresh_token = identity.data.get("refresh_token")

if not refresh_token:
raise IdentityNotValid("Missing refresh token")

# XXX(meredith): This is used in VSTS's `get_refresh_token_params`
kwargs["identity"] = identity
data = self.get_refresh_token_params(refresh_token, *args, **kwargs)
data = self.get_refresh_token_params(refresh_token, identity, **kwargs)

req = safe_urlopen(
url=self.get_refresh_token_url(), headers=self.get_refresh_token_headers(), data=data
Expand Down
9 changes: 6 additions & 3 deletions src/sentry/identity/providers/dummy.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
__all__ = ["DummyProvider"]
from typing import Any

from django.http import HttpResponse
from rest_framework.request import Request

from sentry.identity.base import Provider
from sentry.pipeline import PipelineView
from sentry.users.models.identity import Identity

__all__ = ("DummyProvider",)


class AskEmail(PipelineView):
Expand All @@ -28,5 +31,5 @@ def get_pipeline_views(self) -> list[PipelineView]:
def build_identity(self, state):
return {"id": state["email"], "email": state["email"], "name": "Dummy"}

def refresh_identity(self, auth_identity, *args, **kwargs):
pass
def refresh_identity(self, identity: Identity, **kwargs: Any) -> Identity:
return identity
17 changes: 12 additions & 5 deletions src/sentry/identity/vsts/provider.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
from __future__ import annotations

from typing import Any
from urllib.parse import parse_qsl

import orjson
from django.core.exceptions import PermissionDenied
from rest_framework.request import Request
Expand All @@ -6,6 +11,7 @@
from sentry.identity.oauth2 import OAuth2CallbackView, OAuth2LoginView, OAuth2Provider, record_event
from sentry.integrations.utils.metrics import IntegrationPipelineViewType
from sentry.pipeline.views.base import PipelineView
from sentry.users.models.identity import Identity
from sentry.utils.http import absolute_uri


Expand Down Expand Up @@ -72,8 +78,9 @@ def get_pipeline_views(self) -> list[PipelineView]:
def get_refresh_token_headers(self):
return {"Content-Type": "application/x-www-form-urlencoded", "Content-Length": "1654"}

def get_refresh_token_params(self, refresh_token, *args, **kwargs):
identity = kwargs["identity"]
def get_refresh_token_params(
self, refresh_token: str, identity: Identity, **kwargs: Any
) -> dict[str, str]:
client_secret = options.get("vsts.client-secret")

# The token refresh flow does not operate within a pipeline in the same way
Expand Down Expand Up @@ -117,8 +124,6 @@ def build_identity(self, data):

class VSTSOAuth2CallbackView(OAuth2CallbackView):
def exchange_token(self, request: Request, pipeline, code):
from urllib.parse import parse_qsl

from sentry.http import safe_urlopen, safe_urlread
from sentry.utils.http import absolute_uri

Expand Down Expand Up @@ -184,7 +189,9 @@ def get_pipeline_views(self):
def get_refresh_token_headers(self):
return {"Content-Type": "application/x-www-form-urlencoded", "Content-Length": "1654"}

def get_refresh_token_params(self, refresh_token, *args, **kwargs):
def get_refresh_token_params(
self, refresh_token: str, identity: Identity, **kwargs: Any
) -> dict[str, str]:
# TODO(iamrajjoshi): Fix vsts-limited here
# Note: ignoring the below from the original provider
# # If "vso.code" is missing from the identity.scopes, we know that we installed
Expand Down

0 comments on commit 19672f8

Please sign in to comment.