Skip to content

Commit 7b19175

Browse files
Jonathan Glasertimm4205
authored andcommitted
feat: Add idp_partition connection option
1 parent 5e28b3a commit 7b19175

12 files changed

+250
-15
lines changed

redshift_connector/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ def connect(
153153
client_secret: typing.Optional[str] = None,
154154
partner_sp_id: typing.Optional[str] = None,
155155
idp_response_timeout: typing.Optional[int] = None,
156+
idp_partition: typing.Optional[str] = None,
156157
listen_port: typing.Optional[int] = None,
157158
login_to_rp: typing.Optional[str] = None,
158159
login_url: typing.Optional[str] = None,
@@ -334,6 +335,7 @@ def connect(
334335
info.put("idc_region", idc_region)
335336
info.put("identity_namespace", identity_namespace)
336337
info.put("idp_host", idp_host)
338+
info.put("idp_partition", idp_partition)
337339
info.put("idp_response_timeout", idp_response_timeout)
338340
info.put("idp_tenant", idp_tenant)
339341
info.put("issuer_url", issuer_url)

redshift_connector/plugin/azure_credentials_provider.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
import typing
44

55
from redshift_connector.error import InterfaceError
6+
from redshift_connector.plugin.azure_utils import validate_idp_partition
67
from redshift_connector.plugin.credential_provider_constants import azure_headers
8+
from redshift_connector.plugin.plugin_utils import get_microsoft_idp_host
79
from redshift_connector.plugin.saml_credentials_provider import SamlCredentialsProvider
810
from redshift_connector.redshift_property import RedshiftProperty
911

@@ -23,6 +25,7 @@ def __init__(self: "AzureCredentialsProvider") -> None:
2325
self.idp_tenant: typing.Optional[str] = None
2426
self.client_secret: typing.Optional[str] = None
2527
self.client_id: typing.Optional[str] = None
28+
self.idp_partition: typing.Optional[str] = None
2629

2730
# method to grab the field parameters specified by end user.
2831
# This method adds to it Azure specific parameters.
@@ -34,6 +37,9 @@ def add_parameter(self: "AzureCredentialsProvider", info: RedshiftProperty) -> N
3437
self.client_secret = info.client_secret
3538
# The value of parameter client_id.
3639
self.client_id = info.client_id
40+
41+
# Validate and set idp_partition
42+
self.idp_partition = validate_idp_partition(info.idp_partition)
3743

3844
# Required method to grab the SAML Response. Used in base class to refresh temporary credentials.
3945
def get_saml_assertion(self: "AzureCredentialsProvider") -> str:
@@ -63,7 +69,10 @@ def azure_oauth_based_authentication(self: "AzureCredentialsProvider") -> str:
6369
import requests
6470

6571
# endpoint to connect with Microsoft Azure to get SAML Assertion token
66-
url: str = "https://login.microsoftonline.com/{tenant}/oauth2/token".format(tenant=self.idp_tenant)
72+
url: str = "https://{host}/{tenant}/oauth2/token".format(
73+
host=get_microsoft_idp_host(self.idp_partition),
74+
tenant=self.idp_tenant
75+
)
6776
_logger.debug("Uri: %s", url)
6877
self.validate_url(url)
6978

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
"""Shared utilities for Azure credential providers."""
2+
from redshift_connector.error import InterfaceError
3+
import typing
4+
5+
6+
def validate_idp_partition(idp_partition: typing.Optional[str]) -> typing.Optional[str]:
7+
"""Validate idp_partition parameter and return normalized value."""
8+
if idp_partition is not None:
9+
if not isinstance(idp_partition, str):
10+
raise InterfaceError("idp_partition must be a string")
11+
# Validate against allowed values
12+
valid_partitions = ["", "us-gov", "cn"]
13+
normalized = idp_partition.strip().lower()
14+
if normalized not in valid_partitions:
15+
raise InterfaceError(f"idp_partition must be one of: {', '.join(repr(p) if p else 'empty string' for p in valid_partitions)}")
16+
return idp_partition

redshift_connector/plugin/browser_azure_credentials_provider.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
import typing
77

88
from redshift_connector.error import InterfaceError
9+
from redshift_connector.plugin.azure_utils import validate_idp_partition
910
from redshift_connector.plugin.credential_provider_constants import azure_headers
11+
from redshift_connector.plugin.plugin_utils import get_microsoft_idp_host
1012
from redshift_connector.plugin.saml_credentials_provider import SamlCredentialsProvider
1113
from redshift_connector.redshift_property import RedshiftProperty
1214

@@ -28,6 +30,7 @@ def __init__(self: "BrowserAzureCredentialsProvider") -> None:
2830

2931
self.idp_response_timeout: int = 120
3032
self.listen_port: int = 0
33+
self.idp_partition: typing.Optional[str] = None
3134

3235
self.redirectUri: typing.Optional[str] = None
3336

@@ -48,6 +51,9 @@ def add_parameter(self: "BrowserAzureCredentialsProvider", info: RedshiftPropert
4851
self.idp_tenant = info.idp_tenant
4952
# The value of parameter client_id.
5053
self.client_id = info.client_id
54+
55+
# Validate and set idp_partition
56+
self.idp_partition = validate_idp_partition(info.idp_partition)
5157

5258
self.idp_response_timeout = info.idp_response_timeout
5359

@@ -110,7 +116,10 @@ def fetch_saml_response(self: "BrowserAzureCredentialsProvider", token):
110116
_logger.debug("BrowserAzureCredentialsProvider.fetch_saml_response")
111117
import requests
112118

113-
url: str = "https://login.microsoftonline.com/{tenant}/oauth2/token".format(tenant=self.idp_tenant)
119+
url: str = "https://{host}/{tenant}/oauth2/token".format(
120+
host=get_microsoft_idp_host(self.idp_partition),
121+
tenant=self.idp_tenant
122+
)
114123
# headers to pass with POST request
115124
headers: typing.Dict[str, str] = azure_headers
116125
self.validate_url(url)
@@ -237,16 +246,21 @@ def run_server(
237246
def open_browser(self: "BrowserAzureCredentialsProvider", state: str) -> None:
238247
_logger.debug("BrowserAzureCredentialsProvider.open_browser")
239248
import webbrowser
240-
241-
url: str = (
242-
"https://login.microsoftonline.com/{tenant}"
243-
"/oauth2/authorize"
244-
"?scope=openid"
245-
"&response_type=code"
246-
"&response_mode=form_post"
247-
"&client_id={id}"
248-
"&redirect_uri={uri}"
249-
"&state={state}".format(tenant=self.idp_tenant, id=self.client_id, uri=self.redirectUri, state=state)
249+
from urllib.parse import quote, urlencode
250+
251+
# For query parameters, use urlencode for the entire query string
252+
query_params = {
253+
'scope': 'openid',
254+
'response_type': 'code',
255+
'response_mode': 'form_post',
256+
'client_id': self.client_id,
257+
'redirect_uri': self.redirectUri,
258+
'state': state
259+
}
260+
url: str = "https://{host}/{tenant}/oauth2/authorize?{query}".format(
261+
host=get_microsoft_idp_host(self.idp_partition),
262+
tenant=self.idp_tenant,
263+
query=urlencode(query_params)
250264
)
251265
self.validate_url(url)
252266
webbrowser.open(url)

redshift_connector/plugin/browser_azure_oauth2_credentials_provider.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
from enum import Enum
55

66
from redshift_connector.error import InterfaceError
7+
from redshift_connector.plugin.azure_utils import validate_idp_partition
78
from redshift_connector.plugin.credential_provider_constants import azure_headers
89
from redshift_connector.plugin.jwt_credentials_provider import JwtCredentialsProvider
10+
from redshift_connector.plugin.plugin_utils import get_microsoft_idp_host
911
from redshift_connector.redshift_property import RedshiftProperty
1012

1113
if typing.TYPE_CHECKING:
@@ -35,15 +37,16 @@ class OAuthParamNames(Enum):
3537
SCOPE: str = "scope"
3638
RESPONSE_MODE: str = "response_mode"
3739
RESOURCE: str = "resource"
40+
IDP_PARTITION: str = "idp_partition"
3841

39-
MICROSOFT_IDP_HOST: str = "login.microsoftonline.com"
4042
CURRENT_INTERACTION_SCHEMA: str = "https"
4143

4244
def __init__(self: "BrowserAzureOAuth2CredentialsProvider") -> None:
4345
super().__init__()
4446
self.idp_tenant: typing.Optional[str] = None
4547
self.client_id: typing.Optional[str] = None
4648
self.scope: str = ""
49+
self.idp_partition: typing.Optional[str] = None
4750
self.idp_response_timeout: int = 120
4851
self.listen_port: int = 0
4952

@@ -55,6 +58,9 @@ def add_parameter(
5558
self.idp_tenant = info.idp_tenant
5659
self.client_id = info.client_id
5760
self.scope = info.scope
61+
62+
# Validate and set idp_partition
63+
self.idp_partition = validate_idp_partition(info.idp_partition)
5864

5965
if info.idp_response_timeout:
6066
self.idp_response_timeout = info.idp_response_timeout
@@ -207,7 +213,7 @@ def get_authorization_token_url(self, state: str) -> str:
207213
return urlunsplit(
208214
(
209215
BrowserAzureOAuth2CredentialsProvider.CURRENT_INTERACTION_SCHEMA,
210-
BrowserAzureOAuth2CredentialsProvider.MICROSOFT_IDP_HOST,
216+
get_microsoft_idp_host(self.idp_partition),
211217
"/{}/oauth2/v2.0/authorize".format(self.idp_tenant),
212218
encoded_params,
213219
"",
@@ -251,7 +257,7 @@ def get_jwt_post_request_url(self: "BrowserAzureOAuth2CredentialsProvider") -> s
251257
"""
252258
return "{}://{}{}".format(
253259
BrowserAzureOAuth2CredentialsProvider.CURRENT_INTERACTION_SCHEMA,
254-
BrowserAzureOAuth2CredentialsProvider.MICROSOFT_IDP_HOST,
260+
get_microsoft_idp_host(self.idp_partition),
255261
"/{}/oauth2/v2.0/token".format(self.idp_tenant),
256262
)
257263

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from redshift_connector.error import InterfaceError
2+
import typing
3+
4+
# Microsoft IdP host constants
5+
MICROSOFT_COMMERCIAL_HOST = "login.microsoftonline.com"
6+
MICROSOFT_IDP_HOSTS = {
7+
"us-gov": "login.microsoftonline.us",
8+
"cn": "login.chinacloudapi.cn"
9+
}
10+
11+
12+
def get_microsoft_idp_host(idp_partition: typing.Optional[str] = None) -> str:
13+
"""Returns the appropriate Microsoft IDP host based on the idp_partition value."""
14+
from redshift_connector.plugin.azure_utils import validate_idp_partition
15+
16+
validated_partition = validate_idp_partition(idp_partition)
17+
if not validated_partition or not validated_partition.strip():
18+
return MICROSOFT_COMMERCIAL_HOST
19+
20+
partition = validated_partition.strip().lower()
21+
if partition in MICROSOFT_IDP_HOSTS:
22+
return MICROSOFT_IDP_HOSTS[partition]
23+
else:
24+
supported_values = list(MICROSOFT_IDP_HOSTS.keys())
25+
raise InterfaceError("Invalid IdP partition: '{}' (normalized: '{}'). Supported values are: {}, or empty/None for commercial cloud.".format(
26+
idp_partition, partition, ", ".join("'{}'".format(v) for v in supported_values)
27+
))

redshift_connector/redshift_property.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ def __init__(self: "RedshiftProperty", **kwargs):
6767
self.idp_host: typing.Optional[str] = None
6868
# timeout for authentication via Browser IDP
6969
self.idp_response_timeout: int = 120
70+
# The IdP partition for multi-tenant IdP configurations (Azure AD: us-gov, cn)
71+
self.idp_partition: typing.Optional[str] = None
7072
# The Azure AD tenant ID for your Redshift application.Only used for Azure AD.
7173
self.idp_tenant: typing.Optional[str] = None
7274
# The port used by an IdP (identity provider).

test/conftest.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,25 @@ def redshift_browser_idc() -> typing.Dict[str, typing.Union[str, typing.Optional
360360
return db_connect
361361

362362

363+
@pytest.fixture(scope="class")
364+
def redshift_browser_azure_oauth2() -> typing.Dict[str, typing.Union[str, typing.Optional[bool], int]]:
365+
db_connect = {
366+
"host": conf.get("redshift-browser-azure-oauth2", "host", fallback=None),
367+
"port": conf.getint("redshift-browser-azure-oauth2", "port", fallback=5439),
368+
"database": conf.get("redshift-browser-azure-oauth2", "database", fallback="dev"),
369+
"iam": conf.getboolean("redshift-browser-azure-oauth2", "iam", fallback=True),
370+
"credentials_provider": conf.get(
371+
"redshift-browser-azure-oauth2", "credentials_provider", fallback="BrowserAzureOAuth2CredentialsProvider"
372+
),
373+
"idp_tenant": conf.get("redshift-browser-azure-oauth2", "idp_tenant", fallback=None),
374+
"client_id": conf.get("redshift-browser-azure-oauth2", "client_id", fallback=None),
375+
"scope": conf.get("redshift-browser-azure-oauth2", "scope", fallback=None),
376+
"idp_partition": conf.get("redshift-browser-azure-oauth2", "idp_partition", fallback=None),
377+
"idp_response_timeout": conf.getint("redshift-browser-azure-oauth2", "idp_response_timeout", fallback=120),
378+
}
379+
return db_connect
380+
381+
363382
@pytest.fixture
364383
def con(request, db_kwargs) -> redshift_connector.Connection:
365384
conn: redshift_connector.Connection = redshift_connector.connect(**db_kwargs)
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
"""Common test utilities for Azure credential providers."""
2+
import pytest
3+
from redshift_connector import InterfaceError
4+
from redshift_connector.plugin.azure_utils import validate_idp_partition
5+
from redshift_connector.plugin.plugin_utils import get_microsoft_idp_host
6+
7+
8+
@pytest.mark.parametrize("partition", ["", " ", None])
9+
def test_get_microsoft_idp_host_empty_partition_returns_commercial_host(partition) -> None:
10+
"""Test that empty/None partitions return commercial host."""
11+
assert get_microsoft_idp_host(partition) == "login.microsoftonline.com"
12+
13+
14+
@pytest.mark.parametrize("partition", ["us-gov", "US-GOV", "Us-gov "])
15+
def test_get_microsoft_idp_host_us_gov_partition(partition) -> None:
16+
"""Test that us-gov partitions return US government host."""
17+
assert get_microsoft_idp_host(partition) == "login.microsoftonline.us"
18+
19+
20+
@pytest.mark.parametrize("partition", ["cn", "CN", "Cn "])
21+
def test_get_microsoft_idp_host_china_partition(partition) -> None:
22+
"""Test that cn partitions return China host."""
23+
assert get_microsoft_idp_host(partition) == "login.chinacloudapi.cn"
24+
25+
26+
def test_get_microsoft_idp_host_invalid_partition_throws_error() -> None:
27+
"""Test that invalid partitions raise InterfaceError."""
28+
with pytest.raises(InterfaceError, match="idp_partition must be one of"):
29+
get_microsoft_idp_host("random_partition")
30+
31+
32+
@pytest.mark.parametrize("partition", ["", "us-gov", "cn", None])
33+
def test_validate_idp_partition_valid_values(partition) -> None:
34+
"""Test that valid partition values are accepted."""
35+
result = validate_idp_partition(partition)
36+
assert result == partition
37+
38+
39+
def test_validate_idp_partition_invalid_type() -> None:
40+
"""Test that non-string partition values raise InterfaceError."""
41+
with pytest.raises(InterfaceError, match="idp_partition must be a string"):
42+
validate_idp_partition(123)
43+
44+
45+
def test_validate_idp_partition_invalid_value() -> None:
46+
"""Test that invalid partition values raise InterfaceError."""
47+
with pytest.raises(InterfaceError, match="idp_partition must be one of"):
48+
validate_idp_partition("invalid_partition")

test/unit/plugin/test_azure_credentials_provider.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,15 @@
77
from redshift_connector import InterfaceError, RedshiftProperty
88
from redshift_connector.plugin import AzureCredentialsProvider
99
from redshift_connector.plugin.credential_provider_constants import azure_headers
10+
from redshift_connector.plugin.plugin_utils import get_microsoft_idp_host
11+
12+
# Import common Azure tests
13+
from test.unit.plugin.test_azure_common import (
14+
test_get_microsoft_idp_host_empty_partition_returns_commercial_host,
15+
test_get_microsoft_idp_host_us_gov_partition,
16+
test_get_microsoft_idp_host_china_partition,
17+
test_get_microsoft_idp_host_invalid_partition_throws_error
18+
)
1019

1120

1221
def make_valid_azure_credentials_provider() -> typing.Tuple[AzureCredentialsProvider, RedshiftProperty]:

0 commit comments

Comments
 (0)