Skip to content

Commit b1e9dee

Browse files
feat: add universe domain support to Connector (TPC) (#1045)
1 parent a9a1d0a commit b1e9dee

File tree

4 files changed

+105
-16
lines changed

4 files changed

+105
-16
lines changed

google/cloud/sql/connector/connector.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@
4343
logger = logging.getLogger(name=__name__)
4444

4545
ASYNC_DRIVERS = ["asyncpg"]
46+
_DEFAULT_SCHEME = "https://"
47+
_DEFAULT_UNIVERSE_DOMAIN = "googleapis.com"
48+
_SQLADMIN_HOST_TEMPLATE = "sqladmin.{universe_domain}"
4649

4750

4851
class Connector:
@@ -58,6 +61,7 @@ def __init__(
5861
quota_project: Optional[str] = None,
5962
sqladmin_api_endpoint: Optional[str] = None,
6063
user_agent: Optional[str] = None,
64+
universe_domain: Optional[str] = None,
6165
) -> None:
6266
"""Initializes a Connector instance.
6367
@@ -90,6 +94,10 @@ def __init__(
9094
sqladmin_api_endpoint (str): Base URL to use when calling the Cloud SQL
9195
Admin API endpoint. Defaults to "https://sqladmin.googleapis.com",
9296
this argument should only be used in development.
97+
98+
universe_domain (str): The universe domain for Cloud SQL API calls.
99+
Default: "googleapis.com".
100+
93101
"""
94102
# if event loop is given, use for background tasks
95103
if loop:
@@ -126,12 +134,36 @@ def __init__(
126134
self._timeout = timeout
127135
self._enable_iam_auth = enable_iam_auth
128136
self._quota_project = quota_project
129-
self._sqladmin_api_endpoint = sqladmin_api_endpoint
130137
self._user_agent = user_agent
131138
# if ip_type is str, convert to IPTypes enum
132139
if isinstance(ip_type, str):
133140
ip_type = IPTypes._from_str(ip_type)
134141
self._ip_type = ip_type
142+
self._universe_domain = universe_domain
143+
# construct service endpoint for Cloud SQL Admin API calls
144+
if not sqladmin_api_endpoint:
145+
self._sqladmin_api_endpoint = (
146+
_DEFAULT_SCHEME
147+
+ _SQLADMIN_HOST_TEMPLATE.format(universe_domain=self.universe_domain)
148+
)
149+
# otherwise if endpoint override is passed in use it
150+
else:
151+
self._sqladmin_api_endpoint = sqladmin_api_endpoint
152+
153+
# validate that the universe domain of the credentials matches the
154+
# universe domain of the service endpoint
155+
if self._credentials.universe_domain != self.universe_domain:
156+
raise ValueError(
157+
f"The configured universe domain ({self.universe_domain}) does "
158+
"not match the universe domain found in the credentials "
159+
f"({self._credentials.universe_domain}). If you haven't "
160+
"configured the universe domain explicitly, `googleapis.com` "
161+
"is the default."
162+
)
163+
164+
@property
165+
def universe_domain(self) -> str:
166+
return self._universe_domain or _DEFAULT_UNIVERSE_DOMAIN
135167

136168
def connect(
137169
self, instance_connection_string: str, driver: str, **kwargs: Any
@@ -371,6 +403,7 @@ async def create_async_connector(
371403
quota_project: Optional[str] = None,
372404
sqladmin_api_endpoint: Optional[str] = None,
373405
user_agent: Optional[str] = None,
406+
universe_domain: Optional[str] = None,
374407
) -> Connector:
375408
"""Helper function to create Connector object for asyncio connections.
376409

tests/unit/mocks.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from cryptography.x509.oid import NameOID
3131
from google.auth.credentials import Credentials
3232

33+
from google.cloud.sql.connector.connector import _DEFAULT_UNIVERSE_DOMAIN
3334
from google.cloud.sql.connector.instance import ConnectionInfo
3435
from google.cloud.sql.connector.utils import generate_keys
3536
from google.cloud.sql.connector.utils import write_to_file
@@ -41,6 +42,7 @@ def __init__(
4142
) -> None:
4243
self.token = token
4344
self.expiry = expiry
45+
self._universe_domain = _DEFAULT_UNIVERSE_DOMAIN
4446

4547
@property
4648
def __class__(self) -> Credentials:
@@ -68,6 +70,11 @@ def expired(self) -> bool:
6870
return False
6971
return True
7072

73+
@property
74+
def universe_domain(self) -> str:
75+
"""The universe domain value."""
76+
return self._universe_domain
77+
7178
@property
7279
def valid(self) -> bool:
7380
"""Checks the validity of the credentials.

tests/unit/test_client.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -77,21 +77,6 @@ async def test_CloudSQLClient_init_(fake_credentials: FakeCredentials) -> None:
7777
await client.close()
7878

7979

80-
async def test_CloudSQLClient_init_default_service_endpoint(
81-
fake_credentials: FakeCredentials,
82-
) -> None:
83-
"""
84-
Test to check whether the __init__ method of CloudSQLClient
85-
can correctly initialize the default service endpoint.
86-
"""
87-
driver = "pg8000"
88-
client = CloudSQLClient(None, "my-quota-project", fake_credentials, driver=driver)
89-
# verify base endpoint is set to proper default
90-
assert client._sqladmin_api_endpoint == "https://sqladmin.googleapis.com"
91-
# close client
92-
await client.close()
93-
94-
9580
@pytest.mark.asyncio
9681
async def test_CloudSQLClient_init_custom_user_agent(
9782
fake_credentials: FakeCredentials,

tests/unit/test_connector.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,3 +273,67 @@ def test_Connector_close_called_multiple_times(fake_credentials: Credentials) ->
273273
assert connector._thread.is_alive() is False
274274
# call connector.close a second time
275275
connector.close()
276+
277+
278+
def test_default_universe_domain(fake_credentials: Credentials) -> None:
279+
"""Test that default universe domain and constructed service endpoint are
280+
formatted correctly.
281+
"""
282+
with Connector(credentials=fake_credentials) as connector:
283+
# test universe domain was not configured
284+
assert connector._universe_domain is None
285+
# test property and service endpoint construction
286+
assert connector.universe_domain == "googleapis.com"
287+
assert connector._sqladmin_api_endpoint == "https://sqladmin.googleapis.com"
288+
289+
290+
def test_configured_universe_domain_matches_GDU(fake_credentials: Credentials) -> None:
291+
"""Test that configured universe domain succeeds with matched GDU credentials."""
292+
universe_domain = "googleapis.com"
293+
with Connector(
294+
credentials=fake_credentials, universe_domain=universe_domain
295+
) as connector:
296+
# test universe domain was configured
297+
assert connector._universe_domain == universe_domain
298+
# test property and service endpoint construction
299+
assert connector.universe_domain == universe_domain
300+
assert connector._sqladmin_api_endpoint == f"https://sqladmin.{universe_domain}"
301+
302+
303+
def test_configured_universe_domain_matches_credentials(
304+
fake_credentials: Credentials,
305+
) -> None:
306+
"""Test that configured universe domain succeeds with matching universe
307+
domain credentials.
308+
"""
309+
universe_domain = "test-universe.test"
310+
# set fake credentials to be configured for the universe domain
311+
fake_credentials._universe_domain = universe_domain
312+
with Connector(
313+
credentials=fake_credentials, universe_domain=universe_domain
314+
) as connector:
315+
# test universe domain was configured
316+
assert connector._universe_domain == universe_domain
317+
# test property and service endpoint construction
318+
assert connector.universe_domain == universe_domain
319+
assert connector._sqladmin_api_endpoint == f"https://sqladmin.{universe_domain}"
320+
321+
322+
def test_configured_universe_domain_mismatched_credentials(
323+
fake_credentials: Credentials,
324+
) -> None:
325+
"""Test that configured universe domain errors with mismatched universe
326+
domain credentials.
327+
"""
328+
universe_domain = "test-universe.test"
329+
# credentials have GDU domain ("googleapis.com")
330+
with pytest.raises(ValueError) as exc_info:
331+
Connector(credentials=fake_credentials, universe_domain=universe_domain)
332+
err_msg = (
333+
f"The configured universe domain ({universe_domain}) does "
334+
"not match the universe domain found in the credentials "
335+
f"({fake_credentials.universe_domain}). If you haven't "
336+
"configured the universe domain explicitly, `googleapis.com` "
337+
"is the default."
338+
)
339+
assert exc_info.value.args[0] == err_msg

0 commit comments

Comments
 (0)