Skip to content

Commit de00447

Browse files
authored
feat: Add header back to the client (#2016)
* feat: add back header setting * . * . * . * . * . * . * . * . * . * . * .
1 parent 35db0fb commit de00447

File tree

5 files changed

+145
-9
lines changed

5 files changed

+145
-9
lines changed

google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1/ConnectionWorker.java

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import com.google.api.core.ApiFuture;
1919
import com.google.api.core.SettableApiFuture;
2020
import com.google.api.gax.batching.FlowController;
21+
import com.google.api.gax.rpc.FixedHeaderProvider;
2122
import com.google.auto.value.AutoValue;
2223
import com.google.cloud.bigquery.storage.v1.AppendRowsRequest.ProtoData;
2324
import com.google.cloud.bigquery.storage.v1.Exceptions.AppendSerializtionError;
@@ -77,6 +78,11 @@ class ConnectionWorker implements AutoCloseable {
7778
*/
7879
private String streamName;
7980

81+
/*
82+
* The location of this connection.
83+
*/
84+
private String location = null;
85+
8086
/*
8187
* The proto schema of rows to write. This schema can change during multiplexing.
8288
*/
@@ -211,6 +217,7 @@ public static long getApiMaxRequestBytes() {
211217

212218
public ConnectionWorker(
213219
String streamName,
220+
String location,
214221
ProtoSchema writerSchema,
215222
long maxInflightRequests,
216223
long maxInflightBytes,
@@ -223,6 +230,9 @@ public ConnectionWorker(
223230
this.hasMessageInWaitingQueue = lock.newCondition();
224231
this.inflightReduced = lock.newCondition();
225232
this.streamName = streamName;
233+
if (location != null && !location.isEmpty()) {
234+
this.location = location;
235+
}
226236
this.maxRetryDuration = maxRetryDuration;
227237
if (writerSchema == null) {
228238
throw new StatusRuntimeException(
@@ -236,6 +246,18 @@ public ConnectionWorker(
236246
this.waitingRequestQueue = new LinkedList<AppendRequestAndResponse>();
237247
this.inflightRequestQueue = new LinkedList<AppendRequestAndResponse>();
238248
// Always recreate a client for connection worker.
249+
HashMap<String, String> newHeaders = new HashMap<>();
250+
newHeaders.putAll(clientSettings.toBuilder().getHeaderProvider().getHeaders());
251+
if (this.location == null) {
252+
newHeaders.put("x-goog-request-params", "write_stream=" + this.streamName);
253+
} else {
254+
newHeaders.put("x-goog-request-params", "write_location=" + this.location);
255+
}
256+
BigQueryWriteSettings stubSettings =
257+
clientSettings
258+
.toBuilder()
259+
.setHeaderProvider(FixedHeaderProvider.create(newHeaders))
260+
.build();
239261
this.client = BigQueryWriteClient.create(clientSettings);
240262

241263
this.appendThread =
@@ -297,6 +319,24 @@ public void run(Throwable finalStatus) {
297319

298320
/** Schedules the writing of rows at given offset. */
299321
ApiFuture<AppendRowsResponse> append(StreamWriter streamWriter, ProtoRows rows, long offset) {
322+
if (this.location != null && this.location != streamWriter.getLocation()) {
323+
throw new StatusRuntimeException(
324+
Status.fromCode(Code.INVALID_ARGUMENT)
325+
.withDescription(
326+
"StreamWriter with location "
327+
+ streamWriter.getLocation()
328+
+ " is scheduled to use a connection with location "
329+
+ this.location));
330+
} else if (this.location == null && streamWriter.getStreamName() != this.streamName) {
331+
// Location is null implies this is non-multiplexed connection.
332+
throw new StatusRuntimeException(
333+
Status.fromCode(Code.INVALID_ARGUMENT)
334+
.withDescription(
335+
"StreamWriter with stream name "
336+
+ streamWriter.getStreamName()
337+
+ " is scheduled to use a connection with stream name "
338+
+ this.streamName));
339+
}
300340
Preconditions.checkNotNull(streamWriter);
301341
AppendRowsRequest.Builder requestBuilder = AppendRowsRequest.newBuilder();
302342
requestBuilder.setProtoRows(
@@ -322,6 +362,10 @@ Boolean isUserClosed() {
322362
}
323363
}
324364

365+
String getWriteLocation() {
366+
return this.location;
367+
}
368+
325369
private ApiFuture<AppendRowsResponse> appendInternal(
326370
StreamWriter streamWriter, AppendRowsRequest message) {
327371
AppendRequestAndResponse requestWrapper = new AppendRequestAndResponse(message, streamWriter);

google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1/ConnectionWorkerPool.java

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,8 @@ private ConnectionWorker createOrReuseConnectionWorker(
288288
String streamReference = streamWriter.getStreamName();
289289
if (connectionWorkerPool.size() < currentMaxConnectionCount) {
290290
// Always create a new connection if we haven't reached current maximum.
291-
return createConnectionWorker(streamWriter.getStreamName(), streamWriter.getProtoSchema());
291+
return createConnectionWorker(
292+
streamWriter.getStreamName(), streamWriter.getLocation(), streamWriter.getProtoSchema());
292293
} else {
293294
ConnectionWorker existingBestConnection =
294295
pickBestLoadConnection(
@@ -304,7 +305,10 @@ private ConnectionWorker createOrReuseConnectionWorker(
304305
if (currentMaxConnectionCount > settings.maxConnectionsPerRegion()) {
305306
currentMaxConnectionCount = settings.maxConnectionsPerRegion();
306307
}
307-
return createConnectionWorker(streamWriter.getStreamName(), streamWriter.getProtoSchema());
308+
return createConnectionWorker(
309+
streamWriter.getStreamName(),
310+
streamWriter.getLocation(),
311+
streamWriter.getProtoSchema());
308312
} else {
309313
// Stick to the original connection if all the connections are overwhelmed.
310314
if (existingConnectionWorker != null) {
@@ -359,15 +363,16 @@ static ConnectionWorker pickBestLoadConnection(
359363
* a single stream reference. This is because createConnectionWorker(...) is called via
360364
* computeIfAbsent(...) which is at most once per key.
361365
*/
362-
private ConnectionWorker createConnectionWorker(String streamName, ProtoSchema writeSchema)
363-
throws IOException {
366+
private ConnectionWorker createConnectionWorker(
367+
String streamName, String location, ProtoSchema writeSchema) throws IOException {
364368
if (enableTesting) {
365369
// Though atomic integer is super lightweight, add extra if check in case adding future logic.
366370
testValueCreateConnectionCount.getAndIncrement();
367371
}
368372
ConnectionWorker connectionWorker =
369373
new ConnectionWorker(
370374
streamName,
375+
location,
371376
writeSchema,
372377
maxInflightRequests,
373378
maxInflightBytes,

google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1/StreamWriter.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,7 @@ private StreamWriter(Builder builder) throws IOException {
208208
SingleConnectionOrConnectionPool.ofSingleConnection(
209209
new ConnectionWorker(
210210
builder.streamName,
211+
builder.location,
211212
builder.writerSchema,
212213
builder.maxInflightRequest,
213214
builder.maxInflightBytes,

google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1/ConnectionWorkerPoolTest.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,7 @@ private StreamWriter getTestStreamWriter(String streamName) throws IOException {
430430
return StreamWriter.newBuilder(streamName)
431431
.setWriterSchema(createProtoSchema())
432432
.setTraceId(TEST_TRACE_ID)
433+
.setLocation("us")
433434
.setCredentialsProvider(NoCredentialsProvider.create())
434435
.setChannelProvider(serviceHelper.createChannelProvider())
435436
.build();

google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1/ConnectionWorkerTest.java

Lines changed: 90 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,15 @@
3939
import java.util.List;
4040
import java.util.UUID;
4141
import java.util.concurrent.ExecutionException;
42+
import java.util.logging.Logger;
4243
import org.junit.Before;
4344
import org.junit.Test;
4445
import org.junit.runner.RunWith;
4546
import org.junit.runners.JUnit4;
4647

4748
@RunWith(JUnit4.class)
4849
public class ConnectionWorkerTest {
50+
private static final Logger log = Logger.getLogger(StreamWriter.class.getName());
4951
private static final String TEST_STREAM_1 = "projects/p1/datasets/d1/tables/t1/streams/s1";
5052
private static final String TEST_STREAM_2 = "projects/p2/datasets/d2/tables/t2/streams/s2";
5153
private static final String TEST_TRACE_ID = "DATAFLOW:job_id";
@@ -84,10 +86,12 @@ public void testMultiplexedAppendSuccess() throws Exception {
8486
StreamWriter sw1 =
8587
StreamWriter.newBuilder(TEST_STREAM_1, client)
8688
.setWriterSchema(createProtoSchema("foo"))
89+
.setLocation("us")
8790
.build();
8891
StreamWriter sw2 =
8992
StreamWriter.newBuilder(TEST_STREAM_2, client)
9093
.setWriterSchema(createProtoSchema("complicate"))
94+
.setLocation("us")
9195
.build();
9296
// We do a pattern of:
9397
// send to stream1, string1
@@ -205,11 +209,20 @@ public void testAppendInSameStream_switchSchema() throws Exception {
205209
// send to stream1, schema1
206210
// ...
207211
StreamWriter sw1 =
208-
StreamWriter.newBuilder(TEST_STREAM_1, client).setWriterSchema(schema1).build();
212+
StreamWriter.newBuilder(TEST_STREAM_1, client)
213+
.setLocation("us")
214+
.setWriterSchema(schema1)
215+
.build();
209216
StreamWriter sw2 =
210-
StreamWriter.newBuilder(TEST_STREAM_1, client).setWriterSchema(schema2).build();
217+
StreamWriter.newBuilder(TEST_STREAM_1, client)
218+
.setLocation("us")
219+
.setWriterSchema(schema2)
220+
.build();
211221
StreamWriter sw3 =
212-
StreamWriter.newBuilder(TEST_STREAM_1, client).setWriterSchema(schema3).build();
222+
StreamWriter.newBuilder(TEST_STREAM_1, client)
223+
.setLocation("us")
224+
.setWriterSchema(schema3)
225+
.build();
213226
for (long i = 0; i < appendCount; i++) {
214227
switch ((int) i % 4) {
215228
case 0:
@@ -305,10 +318,14 @@ public void testAppendInSameStream_switchSchema() throws Exception {
305318
public void testAppendButInflightQueueFull() throws Exception {
306319
ProtoSchema schema1 = createProtoSchema("foo");
307320
StreamWriter sw1 =
308-
StreamWriter.newBuilder(TEST_STREAM_1, client).setWriterSchema(schema1).build();
321+
StreamWriter.newBuilder(TEST_STREAM_1, client)
322+
.setLocation("us")
323+
.setWriterSchema(schema1)
324+
.build();
309325
ConnectionWorker connectionWorker =
310326
new ConnectionWorker(
311327
TEST_STREAM_1,
328+
"us",
312329
createProtoSchema("foo"),
313330
6,
314331
100000,
@@ -356,10 +373,14 @@ public void testAppendButInflightQueueFull() throws Exception {
356373
public void testThrowExceptionWhileWithinAppendLoop() throws Exception {
357374
ProtoSchema schema1 = createProtoSchema("foo");
358375
StreamWriter sw1 =
359-
StreamWriter.newBuilder(TEST_STREAM_1, client).setWriterSchema(schema1).build();
376+
StreamWriter.newBuilder(TEST_STREAM_1, client)
377+
.setLocation("us")
378+
.setWriterSchema(schema1)
379+
.build();
360380
ConnectionWorker connectionWorker =
361381
new ConnectionWorker(
362382
TEST_STREAM_1,
383+
"us",
363384
createProtoSchema("foo"),
364385
100000,
365386
100000,
@@ -411,6 +432,69 @@ public void testThrowExceptionWhileWithinAppendLoop() throws Exception {
411432
assertThat(ex.getCause()).hasMessageThat().contains("Any exception can happen.");
412433
}
413434

435+
@Test
436+
public void testLocationMismatch() throws Exception {
437+
ProtoSchema schema1 = createProtoSchema("foo");
438+
StreamWriter sw1 =
439+
StreamWriter.newBuilder(TEST_STREAM_1, client)
440+
.setWriterSchema(schema1)
441+
.setLocation("eu")
442+
.build();
443+
ConnectionWorker connectionWorker =
444+
new ConnectionWorker(
445+
TEST_STREAM_1,
446+
"us",
447+
createProtoSchema("foo"),
448+
100000,
449+
100000,
450+
Duration.ofSeconds(100),
451+
FlowController.LimitExceededBehavior.Block,
452+
TEST_TRACE_ID,
453+
client.getSettings());
454+
StatusRuntimeException ex =
455+
assertThrows(
456+
StatusRuntimeException.class,
457+
() ->
458+
sendTestMessage(
459+
connectionWorker,
460+
sw1,
461+
createFooProtoRows(new String[] {String.valueOf(0)}),
462+
0));
463+
assertEquals(
464+
"INVALID_ARGUMENT: StreamWriter with location eu is scheduled to use a connection with location us",
465+
ex.getMessage());
466+
}
467+
468+
@Test
469+
public void testStreamNameMismatch() throws Exception {
470+
ProtoSchema schema1 = createProtoSchema("foo");
471+
StreamWriter sw1 =
472+
StreamWriter.newBuilder(TEST_STREAM_1, client).setWriterSchema(schema1).build();
473+
ConnectionWorker connectionWorker =
474+
new ConnectionWorker(
475+
TEST_STREAM_2,
476+
null,
477+
createProtoSchema("foo"),
478+
100000,
479+
100000,
480+
Duration.ofSeconds(100),
481+
FlowController.LimitExceededBehavior.Block,
482+
TEST_TRACE_ID,
483+
client.getSettings());
484+
StatusRuntimeException ex =
485+
assertThrows(
486+
StatusRuntimeException.class,
487+
() ->
488+
sendTestMessage(
489+
connectionWorker,
490+
sw1,
491+
createFooProtoRows(new String[] {String.valueOf(0)}),
492+
0));
493+
assertEquals(
494+
"INVALID_ARGUMENT: StreamWriter with stream name projects/p1/datasets/d1/tables/t1/streams/s1 is scheduled to use a connection with stream name projects/p2/datasets/d2/tables/t2/streams/s2",
495+
ex.getMessage());
496+
}
497+
414498
@Test
415499
public void testExponentialBackoff() throws Exception {
416500
assertThat(ConnectionWorker.calculateSleepTimeMilli(0)).isEqualTo(1);
@@ -440,6 +524,7 @@ private ConnectionWorker createConnectionWorker(
440524
throws IOException {
441525
return new ConnectionWorker(
442526
streamName,
527+
"us",
443528
createProtoSchema("foo"),
444529
maxRequests,
445530
maxBytes,

0 commit comments

Comments
 (0)