Skip to content

Commit 97ae8ee

Browse files
authored
fix(sdk): allow SDK to handle protocols in addresses (#70)
TDFs contain embedded URLs, some of which contain protocols. In order for them to work with GRPC we need to strip off the protocol. The logic for ports is to use one if it is specified, otherwise we use 80 if the protocol is `http`, otherwise use `443`.
1 parent c1bbbb4 commit 97ae8ee

File tree

2 files changed

+62
-6
lines changed

2 files changed

+62
-6
lines changed

sdk/src/main/java/io/opentdf/platform/sdk/KASClient.java

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,17 @@
1313
import io.opentdf.platform.kas.PublicKeyRequest;
1414
import io.opentdf.platform.kas.RewrapRequest;
1515

16+
import java.net.MalformedURLException;
17+
import java.net.URL;
1618
import java.time.Duration;
1719
import java.time.Instant;
1820
import java.util.ArrayList;
1921
import java.util.Date;
2022
import java.util.HashMap;
2123
import java.util.function.Function;
2224

25+
import static java.lang.String.format;
26+
2327
public class KASClient implements SDK.KAS, AutoCloseable {
2428

2529
private final Function<String, ManagedChannel> channelFactory;
@@ -51,6 +55,33 @@ public String getPublicKey(Config.KASInfo kasInfo) {
5155
.getPublicKey();
5256
}
5357

58+
private String normalizeAddress(String urlString) {
59+
URL url;
60+
try {
61+
url = new URL(urlString);
62+
} catch (MalformedURLException e) {
63+
// if there is no protocol then they either gave us
64+
// a correct address or one we don't know how to fix
65+
return urlString;
66+
}
67+
68+
// otherwise we take the specified port or default
69+
// based on whether the URL uses a scheme that
70+
// implies TLS
71+
int port;
72+
if (url.getPort() == -1) {
73+
if ("http".equals(url.getProtocol())) {
74+
port = 80;
75+
} else {
76+
port = 443;
77+
}
78+
} else {
79+
port = url.getPort();
80+
}
81+
82+
return format("%s:%d", url.getHost(), port);
83+
}
84+
5485
@Override
5586
public synchronized void close() {
5687
var entries = new ArrayList<>(stubs.values());
@@ -103,21 +134,22 @@ public byte[] unwrap(Manifest.KeyAccess keyAccess, String policy) {
103134
private static class CacheEntry {
104135
final ManagedChannel channel;
105136
final AccessServiceGrpc.AccessServiceBlockingStub stub;
106-
107137
private CacheEntry(ManagedChannel channel, AccessServiceGrpc.AccessServiceBlockingStub stub) {
108138
this.channel = channel;
109139
this.stub = stub;
110140
}
111141
}
112142

113-
private synchronized AccessServiceGrpc.AccessServiceBlockingStub getStub(String url) {
114-
if (!stubs.containsKey(url)) {
115-
var channel = channelFactory.apply(url);
143+
// make this protected so we can test the address normalization logic
144+
synchronized AccessServiceGrpc.AccessServiceBlockingStub getStub(String url) {
145+
var realAddress = normalizeAddress(url);
146+
if (!stubs.containsKey(realAddress)) {
147+
var channel = channelFactory.apply(realAddress);
116148
var stub = AccessServiceGrpc.newBlockingStub(channel);
117-
stubs.put(url, new CacheEntry(channel, stub));
149+
stubs.put(realAddress, new CacheEntry(channel, stub));
118150
}
119151

120-
return stubs.get(url).stub;
152+
return stubs.get(realAddress).stub;
121153
}
122154
}
123155

sdk/src/test/java/io/opentdf/platform/sdk/KASClientTest.java

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import java.text.ParseException;
2525
import java.util.Base64;
2626
import java.util.Random;
27+
import java.util.concurrent.atomic.AtomicReference;
2728
import java.util.function.Function;
2829

2930
import static io.opentdf.platform.sdk.SDKBuilderTest.getRandomPort;
@@ -136,6 +137,29 @@ public void rewrap(RewrapRequest request, StreamObserver<RewrapResponse> respons
136137
}
137138
}
138139

140+
@Test
141+
public void testAddressNormalization() {
142+
var lastAddress = new AtomicReference<String>();
143+
var dpopKeypair = CryptoUtils.generateRSAKeypair();
144+
var dpopKey = new RSAKey.Builder((RSAPublicKey)dpopKeypair.getPublic()).privateKey(dpopKeypair.getPrivate()).build();
145+
var kasClient = new KASClient(addr -> {
146+
lastAddress.set(addr);
147+
return ManagedChannelBuilder.forTarget(addr).build();
148+
}, dpopKey);
149+
150+
var stub = kasClient.getStub("http://localhost:8080");
151+
assertThat(lastAddress.get()).isEqualTo("localhost:8080");
152+
var otherStub = kasClient.getStub("https://localhost:8080");
153+
assertThat(lastAddress.get()).isEqualTo("localhost:8080");
154+
assertThat(stub).isSameAs(otherStub);
155+
156+
kasClient.getStub("https://example.org");
157+
assertThat(lastAddress.get()).isEqualTo("example.org:443");
158+
159+
kasClient.getStub("http://example.org");
160+
assertThat(lastAddress.get()).isEqualTo("example.org:80");
161+
}
162+
139163
private static Server startServer(AccessServiceGrpc.AccessServiceImplBase accessService) throws IOException {
140164
return ServerBuilder
141165
.forPort(getRandomPort())

0 commit comments

Comments
 (0)