Skip to content
22 changes: 19 additions & 3 deletions sdk/src/main/java/io/opentdf/platform/sdk/SDK.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.opentdf.platform.sdk;

import io.grpc.ClientInterceptor;
import io.grpc.ManagedChannel;
import io.opentdf.platform.authorization.AuthorizationServiceGrpc;
import io.opentdf.platform.authorization.AuthorizationServiceGrpc.AuthorizationServiceFutureStub;
Expand All @@ -13,12 +14,17 @@
import io.opentdf.platform.policy.subjectmapping.SubjectMappingServiceGrpc.SubjectMappingServiceFutureStub;
import io.opentdf.platform.sdk.nanotdf.NanoTDFType;

import javax.net.ssl.TrustManager;
import java.util.Optional;

/**
* The SDK class represents a software development kit for interacting with the opentdf platform. It
* provides various services and stubs for making API calls to the opentdf platform.
*/
public class SDK implements AutoCloseable {
private final Services services;
private final TrustManager trustManager;
private final ClientInterceptor authInterceptor;

@Override
public void close() throws Exception {
Expand Down Expand Up @@ -89,11 +95,21 @@ public KAS kas() {
}
}

SDK(Services services) {
public Optional<TrustManager> getTrustManager() {
return Optional.ofNullable(trustManager);
}

public Optional<ClientInterceptor> getAuthInterceptor() {
return Optional.ofNullable(authInterceptor);
}

SDK(Services services, TrustManager trustManager, ClientInterceptor authInterceptor) {
this.services = services;
this.trustManager = trustManager;
this.authInterceptor = authInterceptor;
}

public Services getServices(){
public Services getServices() {
return this.services;
}
}
}
44 changes: 37 additions & 7 deletions sdk/src/main/java/io/opentdf/platform/sdk/SDKBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.net.ssl.TrustManager;
import javax.net.ssl.X509ExtendedTrustManager;
import java.io.File;
import java.io.FileInputStream;
Expand All @@ -29,6 +30,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import java.util.function.Function;

/**
* A builder class for creating instances of the SDK class.
Expand Down Expand Up @@ -145,8 +147,9 @@ private GRPCAuthInterceptor getGrpcAuthInterceptor(RSAKey rsaKey) {
.getFieldsOrThrow(PLATFORM_ISSUER)
.getStringValue();

} catch (StatusRuntimeException e) {
throw new SDKException("Error getting the issuer from the platform", e);
} catch (IllegalArgumentException e) {
logger.warn("no `platform_issuer` found in well known configuration. requests from the SDK will be unauthenticated", e);
return null;
}

Issuer issuer = new Issuer(platformIssuer);
Expand All @@ -164,7 +167,20 @@ private GRPCAuthInterceptor getGrpcAuthInterceptor(RSAKey rsaKey) {
return new GRPCAuthInterceptor(clientAuth, rsaKey, providerMetadata.getTokenEndpointURI(), sslFactory);
}

SDK.Services buildServices() {
static class ServicesAndInternals {
final ClientInterceptor interceptor;
final TrustManager trustManager;

final SDK.Services services;

ServicesAndInternals(ClientInterceptor interceptor, TrustManager trustManager, SDK.Services services) {
this.interceptor = interceptor;
this.trustManager = trustManager;
this.services = services;
}
}

ServicesAndInternals buildServices() {
RSAKey dpopKey;
try {
dpopKey = new RSAKeyGenerator(2048)
Expand All @@ -176,13 +192,27 @@ SDK.Services buildServices() {
}

var authInterceptor = getGrpcAuthInterceptor(dpopKey);
var channel = getManagedChannelBuilder(platformEndpoint).intercept(authInterceptor).build();
var client = new KASClient(endpoint -> getManagedChannelBuilder(endpoint).intercept(authInterceptor).build(), dpopKey);
return SDK.Services.newServices(channel, client);
ManagedChannel channel;
Function<String, ManagedChannel> managedChannelFactory;
if (authInterceptor == null) {
channel = getManagedChannelBuilder(platformEndpoint).build();
managedChannelFactory = (String endpoint) -> getManagedChannelBuilder(endpoint).build();

} else {
channel = getManagedChannelBuilder(platformEndpoint).intercept(authInterceptor).build();
managedChannelFactory = (String endpoint) -> getManagedChannelBuilder(endpoint).intercept(authInterceptor).build();
}
var client = new KASClient(managedChannelFactory, dpopKey);
return new ServicesAndInternals(
authInterceptor,
sslFactory == null ? null : sslFactory.getTrustManager().orElse(null),
SDK.Services.newServices(channel, client)
);
}

public SDK build() {
return new SDK(buildServices());
var services = buildServices();
return new SDK(services.services, services.trustManager, services.interceptor);
}

/**
Expand Down
86 changes: 82 additions & 4 deletions sdk/src/test/java/io/opentdf/platform/sdk/SDKBuilderTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,20 @@

import com.google.protobuf.Struct;
import com.google.protobuf.Value;
import io.grpc.ClientInterceptor;
import io.grpc.Metadata;
import io.grpc.Server;
import io.grpc.ServerBuilder;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.StatusRuntimeException;
import io.grpc.stub.StreamObserver;
import io.opentdf.platform.kas.AccessServiceGrpc;
import io.opentdf.platform.kas.RewrapRequest;
import io.opentdf.platform.kas.RewrapResponse;
import io.opentdf.platform.policy.namespaces.GetNamespaceRequest;
import io.opentdf.platform.policy.namespaces.GetNamespaceResponse;
import io.opentdf.platform.policy.namespaces.NamespaceServiceGrpc;
import io.opentdf.platform.wellknownconfiguration.GetWellKnownConfigurationRequest;
import io.opentdf.platform.wellknownconfiguration.GetWellKnownConfigurationResponse;
Expand All @@ -37,16 +40,16 @@
import java.security.cert.X509Certificate;
import java.util.Arrays;
import java.util.Base64;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicReference;

import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.*;



public class SDKBuilderTest {

final String EXAMPLE_COM_PEM="-----BEGIN CERTIFICATE-----\n" +
final String EXAMPLE_COM_PEM= "-----BEGIN CERTIFICATE-----\n" +
"MIIBqTCCARKgAwIBAgIIT0xFd/5uogEwDQYJKoZIhvcNAQEFBQAwFjEUMBIGA1UEAxMLZXhhbXBs\n" +
"ZS5jb20wIBcNMTcwMTIwMTczOTIwWhgPOTk5OTEyMzEyMzU5NTlaMBYxFDASBgNVBAMTC2V4YW1w\n" +
"bGUuY29tMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQC2Tl2MdaUFmjAaYwmEwgEVRfVqwJO4\n" +
Expand Down Expand Up @@ -224,8 +227,12 @@ public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, Re
certificate()).build());
}

services = servicesBuilder
.buildServices();
var servicesAndComponents = servicesBuilder.buildServices();
if (useSSL) {
assertThat(servicesAndComponents.trustManager).isNotNull();
}
assertThat(servicesAndComponents.interceptor).isNotNull();
services = servicesAndComponents.services;

assertThat(services).isNotNull();

Expand Down Expand Up @@ -295,6 +302,77 @@ public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, Re
}
}

/**
* If auth is disabled then the `platform_issuer` isn't returned during bootstrapping. The SDK
* should still function without auth if auth is disabled on the server
* @throws IOException
*/
@Test
public void testSdkWithNoIssuerMakesRequests() throws IOException {
WellKnownServiceGrpc.WellKnownServiceImplBase wellKnownService = new WellKnownServiceGrpc.WellKnownServiceImplBase() {
@Override
public void getWellKnownConfiguration(GetWellKnownConfigurationRequest request, StreamObserver<GetWellKnownConfigurationResponse> responseObserver) {
// don't return a platform issuer
responseObserver.onNext(GetWellKnownConfigurationResponse.getDefaultInstance());
responseObserver.onCompleted();
}
};

var authHeader = new AtomicReference<String>(null);
var getNsCalled = new AtomicReference<Boolean>(false);

var platformServices = ServerBuilder
.forPort(getRandomPort())
.directExecutor()
.intercept(new ServerInterceptor() {
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next) {
authHeader.set(
headers.get(Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER))
);
return next.startCall(call, headers);
}
})
.addService(wellKnownService)
.addService(new NamespaceServiceGrpc.NamespaceServiceImplBase() {
@Override
public void getNamespace(GetNamespaceRequest request, StreamObserver<GetNamespaceResponse> responseObserver) {
getNsCalled.set(true);
responseObserver.onNext(GetNamespaceResponse.getDefaultInstance());
responseObserver.onCompleted();
}
})
.build();

SDK sdk;
try {
platformServices.start();

sdk = SDKBuilder.newBuilder()
.clientSecret("user", "password")
.platformEndpoint("localhost:" + platformServices.getPort())
.useInsecurePlaintextConnection(true)
.build();
assertThat(sdk.getAuthInterceptor()).isEmpty();


try {
sdk.getServices().namespaces().getNamespace(GetNamespaceRequest.getDefaultInstance()).get();
} catch (StatusRuntimeException ignored) {
} catch (ExecutionException e) {
throw new RuntimeException(e);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new RuntimeException(e);
}

assertThat(getNsCalled.get()).isTrue();
assertThat(authHeader.get()).isNullOrEmpty();
} finally {
platformServices.shutdownNow();
}
}

public static int getRandomPort() throws IOException {
int randomPort;
try (ServerSocket socket = new ServerSocket(0)) {
Expand Down
7 changes: 5 additions & 2 deletions sdk/src/test/java/io/opentdf/platform/sdk/TDFE2ETest.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ public void createAndDecryptTdfIT() throws Exception {
.clientSecret("opentdf-sdk", "secret")
.useInsecurePlaintextConnection(true)
.platformEndpoint("localhost:8080")
.buildServices();
.buildServices()
.services;


var kasInfo = new Config.KASInfo();
kasInfo.URL = "localhost:8080";
Expand Down Expand Up @@ -50,7 +52,8 @@ public void createAndDecryptNanoTDF() throws Exception {
.clientSecret("opentdf-sdk", "secret")
.useInsecurePlaintextConnection(true)
.platformEndpoint("localhost:8080")
.buildServices();
.buildServices()
.services;

var kasInfo = new Config.KASInfo();
kasInfo.URL = "http://localhost:8080";
Expand Down