Skip to content

Commit 2487800

Browse files
fix: enable self signed jwt for grpc (#427)
PiperOrigin-RevId: 386504689 Source-Link: googleapis/googleapis@762094a Source-Link: googleapis/googleapis-gen@6bfc480
1 parent d842233 commit 2487800

File tree

6 files changed

+66
-39
lines changed

6 files changed

+66
-39
lines changed

google/cloud/spanner_admin_database_v1/services/database_admin/client.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,10 @@ def __init__(
435435
client_cert_source_for_mtls=client_cert_source_func,
436436
quota_project_id=client_options.quota_project_id,
437437
client_info=client_info,
438+
always_use_jwt_access=(
439+
Transport == type(self).get_transport_class("grpc")
440+
or Transport == type(self).get_transport_class("grpc_asyncio")
441+
),
438442
)
439443

440444
def list_databases(

google/cloud/spanner_admin_instance_v1/services/instance_admin/client.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,10 @@ def __init__(
381381
client_cert_source_for_mtls=client_cert_source_func,
382382
quota_project_id=client_options.quota_project_id,
383383
client_info=client_info,
384+
always_use_jwt_access=(
385+
Transport == type(self).get_transport_class("grpc")
386+
or Transport == type(self).get_transport_class("grpc_asyncio")
387+
),
384388
)
385389

386390
def list_instance_configs(

google/cloud/spanner_v1/services/spanner/client.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,10 @@ def __init__(
368368
client_cert_source_for_mtls=client_cert_source_func,
369369
quota_project_id=client_options.quota_project_id,
370370
client_info=client_info,
371+
always_use_jwt_access=(
372+
Transport == type(self).get_transport_class("grpc")
373+
or Transport == type(self).get_transport_class("grpc_asyncio")
374+
),
371375
)
372376

373377
def create_session(

tests/unit/gapic/spanner_admin_database_v1/test_database_admin.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -138,26 +138,14 @@ def test_database_admin_client_from_service_account_info(client_class):
138138
assert client.transport._host == "spanner.googleapis.com:443"
139139

140140

141-
@pytest.mark.parametrize(
142-
"client_class", [DatabaseAdminClient, DatabaseAdminAsyncClient,]
143-
)
144-
def test_database_admin_client_service_account_always_use_jwt(client_class):
145-
with mock.patch.object(
146-
service_account.Credentials, "with_always_use_jwt_access", create=True
147-
) as use_jwt:
148-
creds = service_account.Credentials(None, None, None)
149-
client = client_class(credentials=creds)
150-
use_jwt.assert_not_called()
151-
152-
153141
@pytest.mark.parametrize(
154142
"transport_class,transport_name",
155143
[
156144
(transports.DatabaseAdminGrpcTransport, "grpc"),
157145
(transports.DatabaseAdminGrpcAsyncIOTransport, "grpc_asyncio"),
158146
],
159147
)
160-
def test_database_admin_client_service_account_always_use_jwt_true(
148+
def test_database_admin_client_service_account_always_use_jwt(
161149
transport_class, transport_name
162150
):
163151
with mock.patch.object(
@@ -167,6 +155,13 @@ def test_database_admin_client_service_account_always_use_jwt_true(
167155
transport = transport_class(credentials=creds, always_use_jwt_access=True)
168156
use_jwt.assert_called_once_with(True)
169157

158+
with mock.patch.object(
159+
service_account.Credentials, "with_always_use_jwt_access", create=True
160+
) as use_jwt:
161+
creds = service_account.Credentials(None, None, None)
162+
transport = transport_class(credentials=creds, always_use_jwt_access=False)
163+
use_jwt.assert_not_called()
164+
170165

171166
@pytest.mark.parametrize(
172167
"client_class", [DatabaseAdminClient, DatabaseAdminAsyncClient,]
@@ -247,6 +242,7 @@ def test_database_admin_client_client_options(
247242
client_cert_source_for_mtls=None,
248243
quota_project_id=None,
249244
client_info=transports.base.DEFAULT_CLIENT_INFO,
245+
always_use_jwt_access=True,
250246
)
251247

252248
# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is
@@ -263,6 +259,7 @@ def test_database_admin_client_client_options(
263259
client_cert_source_for_mtls=None,
264260
quota_project_id=None,
265261
client_info=transports.base.DEFAULT_CLIENT_INFO,
262+
always_use_jwt_access=True,
266263
)
267264

268265
# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is
@@ -279,6 +276,7 @@ def test_database_admin_client_client_options(
279276
client_cert_source_for_mtls=None,
280277
quota_project_id=None,
281278
client_info=transports.base.DEFAULT_CLIENT_INFO,
279+
always_use_jwt_access=True,
282280
)
283281

284282
# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has
@@ -307,6 +305,7 @@ def test_database_admin_client_client_options(
307305
client_cert_source_for_mtls=None,
308306
quota_project_id="octopus",
309307
client_info=transports.base.DEFAULT_CLIENT_INFO,
308+
always_use_jwt_access=True,
310309
)
311310

312311

@@ -373,6 +372,7 @@ def test_database_admin_client_mtls_env_auto(
373372
client_cert_source_for_mtls=expected_client_cert_source,
374373
quota_project_id=None,
375374
client_info=transports.base.DEFAULT_CLIENT_INFO,
375+
always_use_jwt_access=True,
376376
)
377377

378378
# Check the case ADC client cert is provided. Whether client cert is used depends on
@@ -406,6 +406,7 @@ def test_database_admin_client_mtls_env_auto(
406406
client_cert_source_for_mtls=expected_client_cert_source,
407407
quota_project_id=None,
408408
client_info=transports.base.DEFAULT_CLIENT_INFO,
409+
always_use_jwt_access=True,
409410
)
410411

411412
# Check the case client_cert_source and ADC client cert are not provided.
@@ -427,6 +428,7 @@ def test_database_admin_client_mtls_env_auto(
427428
client_cert_source_for_mtls=None,
428429
quota_project_id=None,
429430
client_info=transports.base.DEFAULT_CLIENT_INFO,
431+
always_use_jwt_access=True,
430432
)
431433

432434

@@ -457,6 +459,7 @@ def test_database_admin_client_client_options_scopes(
457459
client_cert_source_for_mtls=None,
458460
quota_project_id=None,
459461
client_info=transports.base.DEFAULT_CLIENT_INFO,
462+
always_use_jwt_access=True,
460463
)
461464

462465

@@ -487,6 +490,7 @@ def test_database_admin_client_client_options_credentials_file(
487490
client_cert_source_for_mtls=None,
488491
quota_project_id=None,
489492
client_info=transports.base.DEFAULT_CLIENT_INFO,
493+
always_use_jwt_access=True,
490494
)
491495

492496

@@ -506,6 +510,7 @@ def test_database_admin_client_client_options_from_dict():
506510
client_cert_source_for_mtls=None,
507511
quota_project_id=None,
508512
client_info=transports.base.DEFAULT_CLIENT_INFO,
513+
always_use_jwt_access=True,
509514
)
510515

511516

tests/unit/gapic/spanner_admin_instance_v1/test_instance_admin.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -131,26 +131,14 @@ def test_instance_admin_client_from_service_account_info(client_class):
131131
assert client.transport._host == "spanner.googleapis.com:443"
132132

133133

134-
@pytest.mark.parametrize(
135-
"client_class", [InstanceAdminClient, InstanceAdminAsyncClient,]
136-
)
137-
def test_instance_admin_client_service_account_always_use_jwt(client_class):
138-
with mock.patch.object(
139-
service_account.Credentials, "with_always_use_jwt_access", create=True
140-
) as use_jwt:
141-
creds = service_account.Credentials(None, None, None)
142-
client = client_class(credentials=creds)
143-
use_jwt.assert_not_called()
144-
145-
146134
@pytest.mark.parametrize(
147135
"transport_class,transport_name",
148136
[
149137
(transports.InstanceAdminGrpcTransport, "grpc"),
150138
(transports.InstanceAdminGrpcAsyncIOTransport, "grpc_asyncio"),
151139
],
152140
)
153-
def test_instance_admin_client_service_account_always_use_jwt_true(
141+
def test_instance_admin_client_service_account_always_use_jwt(
154142
transport_class, transport_name
155143
):
156144
with mock.patch.object(
@@ -160,6 +148,13 @@ def test_instance_admin_client_service_account_always_use_jwt_true(
160148
transport = transport_class(credentials=creds, always_use_jwt_access=True)
161149
use_jwt.assert_called_once_with(True)
162150

151+
with mock.patch.object(
152+
service_account.Credentials, "with_always_use_jwt_access", create=True
153+
) as use_jwt:
154+
creds = service_account.Credentials(None, None, None)
155+
transport = transport_class(credentials=creds, always_use_jwt_access=False)
156+
use_jwt.assert_not_called()
157+
163158

164159
@pytest.mark.parametrize(
165160
"client_class", [InstanceAdminClient, InstanceAdminAsyncClient,]
@@ -240,6 +235,7 @@ def test_instance_admin_client_client_options(
240235
client_cert_source_for_mtls=None,
241236
quota_project_id=None,
242237
client_info=transports.base.DEFAULT_CLIENT_INFO,
238+
always_use_jwt_access=True,
243239
)
244240

245241
# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is
@@ -256,6 +252,7 @@ def test_instance_admin_client_client_options(
256252
client_cert_source_for_mtls=None,
257253
quota_project_id=None,
258254
client_info=transports.base.DEFAULT_CLIENT_INFO,
255+
always_use_jwt_access=True,
259256
)
260257

261258
# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is
@@ -272,6 +269,7 @@ def test_instance_admin_client_client_options(
272269
client_cert_source_for_mtls=None,
273270
quota_project_id=None,
274271
client_info=transports.base.DEFAULT_CLIENT_INFO,
272+
always_use_jwt_access=True,
275273
)
276274

277275
# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has
@@ -300,6 +298,7 @@ def test_instance_admin_client_client_options(
300298
client_cert_source_for_mtls=None,
301299
quota_project_id="octopus",
302300
client_info=transports.base.DEFAULT_CLIENT_INFO,
301+
always_use_jwt_access=True,
303302
)
304303

305304

@@ -366,6 +365,7 @@ def test_instance_admin_client_mtls_env_auto(
366365
client_cert_source_for_mtls=expected_client_cert_source,
367366
quota_project_id=None,
368367
client_info=transports.base.DEFAULT_CLIENT_INFO,
368+
always_use_jwt_access=True,
369369
)
370370

371371
# Check the case ADC client cert is provided. Whether client cert is used depends on
@@ -399,6 +399,7 @@ def test_instance_admin_client_mtls_env_auto(
399399
client_cert_source_for_mtls=expected_client_cert_source,
400400
quota_project_id=None,
401401
client_info=transports.base.DEFAULT_CLIENT_INFO,
402+
always_use_jwt_access=True,
402403
)
403404

404405
# Check the case client_cert_source and ADC client cert are not provided.
@@ -420,6 +421,7 @@ def test_instance_admin_client_mtls_env_auto(
420421
client_cert_source_for_mtls=None,
421422
quota_project_id=None,
422423
client_info=transports.base.DEFAULT_CLIENT_INFO,
424+
always_use_jwt_access=True,
423425
)
424426

425427

@@ -450,6 +452,7 @@ def test_instance_admin_client_client_options_scopes(
450452
client_cert_source_for_mtls=None,
451453
quota_project_id=None,
452454
client_info=transports.base.DEFAULT_CLIENT_INFO,
455+
always_use_jwt_access=True,
453456
)
454457

455458

@@ -480,6 +483,7 @@ def test_instance_admin_client_client_options_credentials_file(
480483
client_cert_source_for_mtls=None,
481484
quota_project_id=None,
482485
client_info=transports.base.DEFAULT_CLIENT_INFO,
486+
always_use_jwt_access=True,
483487
)
484488

485489

@@ -499,6 +503,7 @@ def test_instance_admin_client_client_options_from_dict():
499503
client_cert_source_for_mtls=None,
500504
quota_project_id=None,
501505
client_info=transports.base.DEFAULT_CLIENT_INFO,
506+
always_use_jwt_access=True,
502507
)
503508

504509

tests/unit/gapic/spanner_v1/test_spanner.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -120,33 +120,28 @@ def test_spanner_client_from_service_account_info(client_class):
120120
assert client.transport._host == "spanner.googleapis.com:443"
121121

122122

123-
@pytest.mark.parametrize("client_class", [SpannerClient, SpannerAsyncClient,])
124-
def test_spanner_client_service_account_always_use_jwt(client_class):
125-
with mock.patch.object(
126-
service_account.Credentials, "with_always_use_jwt_access", create=True
127-
) as use_jwt:
128-
creds = service_account.Credentials(None, None, None)
129-
client = client_class(credentials=creds)
130-
use_jwt.assert_not_called()
131-
132-
133123
@pytest.mark.parametrize(
134124
"transport_class,transport_name",
135125
[
136126
(transports.SpannerGrpcTransport, "grpc"),
137127
(transports.SpannerGrpcAsyncIOTransport, "grpc_asyncio"),
138128
],
139129
)
140-
def test_spanner_client_service_account_always_use_jwt_true(
141-
transport_class, transport_name
142-
):
130+
def test_spanner_client_service_account_always_use_jwt(transport_class, transport_name):
143131
with mock.patch.object(
144132
service_account.Credentials, "with_always_use_jwt_access", create=True
145133
) as use_jwt:
146134
creds = service_account.Credentials(None, None, None)
147135
transport = transport_class(credentials=creds, always_use_jwt_access=True)
148136
use_jwt.assert_called_once_with(True)
149137

138+
with mock.patch.object(
139+
service_account.Credentials, "with_always_use_jwt_access", create=True
140+
) as use_jwt:
141+
creds = service_account.Credentials(None, None, None)
142+
transport = transport_class(credentials=creds, always_use_jwt_access=False)
143+
use_jwt.assert_not_called()
144+
150145

151146
@pytest.mark.parametrize("client_class", [SpannerClient, SpannerAsyncClient,])
152147
def test_spanner_client_from_service_account_file(client_class):
@@ -215,6 +210,7 @@ def test_spanner_client_client_options(client_class, transport_class, transport_
215210
client_cert_source_for_mtls=None,
216211
quota_project_id=None,
217212
client_info=transports.base.DEFAULT_CLIENT_INFO,
213+
always_use_jwt_access=True,
218214
)
219215

220216
# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is
@@ -231,6 +227,7 @@ def test_spanner_client_client_options(client_class, transport_class, transport_
231227
client_cert_source_for_mtls=None,
232228
quota_project_id=None,
233229
client_info=transports.base.DEFAULT_CLIENT_INFO,
230+
always_use_jwt_access=True,
234231
)
235232

236233
# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is
@@ -247,6 +244,7 @@ def test_spanner_client_client_options(client_class, transport_class, transport_
247244
client_cert_source_for_mtls=None,
248245
quota_project_id=None,
249246
client_info=transports.base.DEFAULT_CLIENT_INFO,
247+
always_use_jwt_access=True,
250248
)
251249

252250
# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has
@@ -275,6 +273,7 @@ def test_spanner_client_client_options(client_class, transport_class, transport_
275273
client_cert_source_for_mtls=None,
276274
quota_project_id="octopus",
277275
client_info=transports.base.DEFAULT_CLIENT_INFO,
276+
always_use_jwt_access=True,
278277
)
279278

280279

@@ -337,6 +336,7 @@ def test_spanner_client_mtls_env_auto(
337336
client_cert_source_for_mtls=expected_client_cert_source,
338337
quota_project_id=None,
339338
client_info=transports.base.DEFAULT_CLIENT_INFO,
339+
always_use_jwt_access=True,
340340
)
341341

342342
# Check the case ADC client cert is provided. Whether client cert is used depends on
@@ -370,6 +370,7 @@ def test_spanner_client_mtls_env_auto(
370370
client_cert_source_for_mtls=expected_client_cert_source,
371371
quota_project_id=None,
372372
client_info=transports.base.DEFAULT_CLIENT_INFO,
373+
always_use_jwt_access=True,
373374
)
374375

375376
# Check the case client_cert_source and ADC client cert are not provided.
@@ -391,6 +392,7 @@ def test_spanner_client_mtls_env_auto(
391392
client_cert_source_for_mtls=None,
392393
quota_project_id=None,
393394
client_info=transports.base.DEFAULT_CLIENT_INFO,
395+
always_use_jwt_access=True,
394396
)
395397

396398

@@ -417,6 +419,7 @@ def test_spanner_client_client_options_scopes(
417419
client_cert_source_for_mtls=None,
418420
quota_project_id=None,
419421
client_info=transports.base.DEFAULT_CLIENT_INFO,
422+
always_use_jwt_access=True,
420423
)
421424

422425

@@ -443,6 +446,7 @@ def test_spanner_client_client_options_credentials_file(
443446
client_cert_source_for_mtls=None,
444447
quota_project_id=None,
445448
client_info=transports.base.DEFAULT_CLIENT_INFO,
449+
always_use_jwt_access=True,
446450
)
447451

448452

@@ -460,6 +464,7 @@ def test_spanner_client_client_options_from_dict():
460464
client_cert_source_for_mtls=None,
461465
quota_project_id=None,
462466
client_info=transports.base.DEFAULT_CLIENT_INFO,
467+
always_use_jwt_access=True,
463468
)
464469

465470

0 commit comments

Comments
 (0)