diff --git a/google/api_core/universe.py b/google/api_core/universe.py new file mode 100644 index 00000000..35669642 --- /dev/null +++ b/google/api_core/universe.py @@ -0,0 +1,82 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helpers for universe domain.""" + +from typing import Any, Optional + +DEFAULT_UNIVERSE = "googleapis.com" + + +class EmptyUniverseError(ValueError): + def __init__(self): + message = "Universe Domain cannot be an empty string." + super().__init__(message) + + +class UniverseMismatchError(ValueError): + def __init__(self, client_universe, credentials_universe): + message = ( + f"The configured universe domain ({client_universe}) does not match the universe domain " + f"found in the credentials ({credentials_universe}). " + "If you haven't configured the universe domain explicitly, " + f"`{DEFAULT_UNIVERSE}` is the default." + ) + super().__init__(message) + + +def determine_domain( + client_universe_domain: Optional[str], universe_domain_env: Optional[str] +) -> str: + """Return the universe domain used by the client. + + Args: + client_universe_domain (Optional[str]): The universe domain configured via the client options. + universe_domain_env (Optional[str]): The universe domain configured via the + "GOOGLE_CLOUD_UNIVERSE_DOMAIN" environment variable. + + Returns: + str: The universe domain to be used by the client. + + Raises: + ValueError: If the universe domain is an empty string. + """ + universe_domain = DEFAULT_UNIVERSE + if client_universe_domain is not None: + universe_domain = client_universe_domain + elif universe_domain_env is not None: + universe_domain = universe_domain_env + if len(universe_domain.strip()) == 0: + raise EmptyUniverseError + return universe_domain + + +def compare_domains(client_universe: str, credentials: Any) -> bool: + """Returns True iff the universe domains used by the client and credentials match. + + Args: + client_universe (str): The universe domain configured via the client options. + credentials Any: The credentials being used in the client. + + Returns: + bool: True iff client_universe matches the universe in credentials. + + Raises: + ValueError: when client_universe does not match the universe in credentials. + """ + credentials_universe = getattr(credentials, "universe_domain", DEFAULT_UNIVERSE) + + if client_universe != credentials_universe: + raise UniverseMismatchError(client_universe, credentials_universe) + return True diff --git a/tests/unit/test_universe.py b/tests/unit/test_universe.py new file mode 100644 index 00000000..214e00ac --- /dev/null +++ b/tests/unit/test_universe.py @@ -0,0 +1,63 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from google.api_core import universe + + +class _Fake_Credentials: + def __init__(self, universe_domain=None): + if universe_domain: + self.universe_domain = universe_domain + + +def test_determine_domain(): + domain_client = "foo.com" + domain_env = "bar.com" + + assert universe.determine_domain(domain_client, domain_env) == domain_client + assert universe.determine_domain(None, domain_env) == domain_env + assert universe.determine_domain(domain_client, None) == domain_client + assert universe.determine_domain(None, None) == universe.DEFAULT_UNIVERSE + + with pytest.raises(universe.EmptyUniverseError): + universe.determine_domain("", None) + + with pytest.raises(universe.EmptyUniverseError): + universe.determine_domain(None, "") + + +def test_compare_domains(): + fake_domain = "foo.com" + another_fake_domain = "bar.com" + + assert universe.compare_domains(universe.DEFAULT_UNIVERSE, _Fake_Credentials()) + assert universe.compare_domains(fake_domain, _Fake_Credentials(fake_domain)) + + with pytest.raises(universe.UniverseMismatchError) as excinfo: + universe.compare_domains( + universe.DEFAULT_UNIVERSE, _Fake_Credentials(fake_domain) + ) + assert str(excinfo.value).find(universe.DEFAULT_UNIVERSE) >= 0 + assert str(excinfo.value).find(fake_domain) >= 0 + + with pytest.raises(universe.UniverseMismatchError) as excinfo: + universe.compare_domains(fake_domain, _Fake_Credentials()) + assert str(excinfo.value).find(fake_domain) >= 0 + assert str(excinfo.value).find(universe.DEFAULT_UNIVERSE) >= 0 + + with pytest.raises(universe.UniverseMismatchError) as excinfo: + universe.compare_domains(fake_domain, _Fake_Credentials(another_fake_domain)) + assert str(excinfo.value).find(fake_domain) >= 0 + assert str(excinfo.value).find(another_fake_domain) >= 0