diff --git a/google/oauth2/credentials.py b/google/oauth2/credentials.py index 41f4a05bb..c239beed1 100644 --- a/google/oauth2/credentials.py +++ b/google/oauth2/credentials.py @@ -87,6 +87,7 @@ def __init__( granted_scopes=None, trust_boundary=None, universe_domain=_DEFAULT_UNIVERSE_DOMAIN, + account=None, ): """ Args: @@ -131,6 +132,7 @@ def __init__( trust_boundary (str): String representation of trust boundary meta. universe_domain (Optional[str]): The universe domain. The default universe domain is googleapis.com. + account (Optional[str]): The account associated with the credential. """ super(Credentials, self).__init__() self.token = token @@ -149,6 +151,7 @@ def __init__( self._enable_reauth_refresh = enable_reauth_refresh self._trust_boundary = trust_boundary self._universe_domain = universe_domain or _DEFAULT_UNIVERSE_DOMAIN + self._account = account or "" def __getstate__(self): """A __getstate__ method must exist for the __setstate__ to be called @@ -189,6 +192,7 @@ def __setstate__(self, d): self._refresh_handler = None self._refresh_worker = None self._use_non_blocking_refresh = d.get("_use_non_blocking_refresh", False) + self._account = d.get("_account", "") @property def refresh_token(self): @@ -268,6 +272,11 @@ def refresh_handler(self, value): raise TypeError("The provided refresh_handler is not a callable or None.") self._refresh_handler = value + @property + def account(self): + """str: The user account associated with the credential. If the account is unknown an empty string is returned.""" + return self._account + @_helpers.copy_docstring(credentials.CredentialsWithQuotaProject) def with_quota_project(self, quota_project_id): @@ -286,6 +295,7 @@ def with_quota_project(self, quota_project_id): enable_reauth_refresh=self._enable_reauth_refresh, trust_boundary=self._trust_boundary, universe_domain=self._universe_domain, + account=self._account, ) @_helpers.copy_docstring(credentials.CredentialsWithTokenUri) @@ -306,6 +316,35 @@ def with_token_uri(self, token_uri): enable_reauth_refresh=self._enable_reauth_refresh, trust_boundary=self._trust_boundary, universe_domain=self._universe_domain, + account=self._account, + ) + + def with_account(self, account): + """Returns a copy of these credentials with a modified account. + + Args: + account (str): The account to set + + Returns: + google.oauth2.credentials.Credentials: A new credentials instance. + """ + + return self.__class__( + self.token, + refresh_token=self.refresh_token, + id_token=self.id_token, + token_uri=self._token_uri, + client_id=self.client_id, + client_secret=self.client_secret, + scopes=self.scopes, + default_scopes=self.default_scopes, + granted_scopes=self.granted_scopes, + quota_project_id=self.quota_project_id, + rapt_token=self.rapt_token, + enable_reauth_refresh=self._enable_reauth_refresh, + trust_boundary=self._trust_boundary, + universe_domain=self._universe_domain, + account=account, ) @_helpers.copy_docstring(credentials.CredentialsWithUniverseDomain) @@ -326,6 +365,7 @@ def with_universe_domain(self, universe_domain): enable_reauth_refresh=self._enable_reauth_refresh, trust_boundary=self._trust_boundary, universe_domain=universe_domain, + account=self._account, ) def _metric_header_for_usage(self): @@ -474,6 +514,7 @@ def from_authorized_user_info(cls, info, scopes=None): rapt_token=info.get("rapt_token"), # may not exist trust_boundary=info.get("trust_boundary"), # may not exist universe_domain=info.get("universe_domain"), # may not exist + account=info.get("account", ""), # may not exist ) @classmethod @@ -518,6 +559,7 @@ def to_json(self, strip=None): "scopes": self.scopes, "rapt_token": self.rapt_token, "universe_domain": self._universe_domain, + "account": self._account, } if self.expiry: # flatten expiry timestamp prep["expiry"] = self.expiry.isoformat() + "Z" diff --git a/system_tests/secrets.tar.enc b/system_tests/secrets.tar.enc index 61417809c..78e375c05 100644 Binary files a/system_tests/secrets.tar.enc and b/system_tests/secrets.tar.enc differ diff --git a/tests/oauth2/test_credentials.py b/tests/oauth2/test_credentials.py index 7516fe22e..216641946 100644 --- a/tests/oauth2/test_credentials.py +++ b/tests/oauth2/test_credentials.py @@ -793,6 +793,12 @@ def test_with_universe_domain(self): new_creds = creds.with_universe_domain("dummy_universe.com") assert new_creds.universe_domain == "dummy_universe.com" + def test_with_account(self): + creds = credentials.Credentials(token="token") + assert creds.account == "" + new_creds = creds.with_account("mock@example.com") + assert new_creds.account == "mock@example.com" + def test_with_token_uri(self): info = AUTH_USER_INFO.copy() @@ -888,6 +894,7 @@ def test_to_json(self): assert json_asdict.get("client_secret") == creds.client_secret assert json_asdict.get("expiry") == info["expiry"] assert json_asdict.get("universe_domain") == creds.universe_domain + assert json_asdict.get("account") == creds.account # Test with a `strip` arg json_output = creds.to_json(strip=["client_secret"])