Skip to content

Commit

Permalink
Migrate many usages of TestUtils.loadCert() to the public TlsTesting
Browse files Browse the repository at this point in the history
TlsTesting.loadCert() is a public API and so should be preferred over
our internal utility. It avoids creating a temp file that has to be
deleted by a shutdown hook. Usages that needed a file were not migrated.
  • Loading branch information
ejona86 committed May 10, 2023
1 parent f229aed commit 74b515e
Show file tree
Hide file tree
Showing 19 changed files with 124 additions and 136 deletions.
24 changes: 8 additions & 16 deletions authz/src/test/java/io/grpc/authz/AuthorizationEnd2EndTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@
import io.grpc.TlsServerCredentials;
import io.grpc.TlsServerCredentials.ClientAuth;
import io.grpc.internal.FakeClock;
import io.grpc.internal.testing.TestUtils;
import io.grpc.stub.StreamObserver;
import io.grpc.testing.TlsTesting;
import io.grpc.testing.protobuf.SimpleRequest;
import io.grpc.testing.protobuf.SimpleResponse;
import io.grpc.testing.protobuf.SimpleServiceGrpc;
Expand Down Expand Up @@ -343,11 +343,6 @@ public void staticAuthzDeniesRpcWithPrincipalsFieldOnUnauthenticatedConnectionTe
@Test
public void staticAuthzAllowsRpcWithPrincipalsFieldOnMtlsAuthenticatedConnectionTest()
throws Exception {
File caCertFile = TestUtils.loadCert(CA_PEM_FILE);
File serverKey0File = TestUtils.loadCert(SERVER_0_KEY_FILE);
File serverCert0File = TestUtils.loadCert(SERVER_0_PEM_FILE);
File clientKey0File = TestUtils.loadCert(CLIENT_0_KEY_FILE);
File clientCert0File = TestUtils.loadCert(CLIENT_0_PEM_FILE);
String policy = "{"
+ " \"name\" : \"authz\" ,"
+ " \"allow_rules\": ["
Expand All @@ -361,24 +356,21 @@ public void staticAuthzAllowsRpcWithPrincipalsFieldOnMtlsAuthenticatedConnection
+ "}";
AuthorizationServerInterceptor interceptor = createStaticAuthorizationInterceptor(policy);
ServerCredentials serverCredentials = TlsServerCredentials.newBuilder()
.keyManager(serverCert0File, serverKey0File)
.trustManager(caCertFile)
.keyManager(TlsTesting.loadCert(SERVER_0_PEM_FILE), TlsTesting.loadCert(SERVER_0_KEY_FILE))
.trustManager(TlsTesting.loadCert(CA_PEM_FILE))
.clientAuth(ClientAuth.REQUIRE)
.build();
initServerWithAuthzInterceptor(interceptor, serverCredentials);
ChannelCredentials channelCredentials = TlsChannelCredentials.newBuilder()
.keyManager(clientCert0File, clientKey0File)
.trustManager(caCertFile)
.keyManager(TlsTesting.loadCert(CLIENT_0_PEM_FILE), TlsTesting.loadCert(CLIENT_0_KEY_FILE))
.trustManager(TlsTesting.loadCert(CA_PEM_FILE))
.build();
getStub(channelCredentials).unaryRpc(SimpleRequest.getDefaultInstance());
}

@Test
public void staticAuthzAllowsRpcWithPrincipalsFieldOnTlsAuthenticatedConnectionTest()
throws Exception {
File caCertFile = TestUtils.loadCert(CA_PEM_FILE);
File serverKey0File = TestUtils.loadCert(SERVER_0_KEY_FILE);
File serverCert0File = TestUtils.loadCert(SERVER_0_PEM_FILE);
String policy = "{"
+ " \"name\" : \"authz\" ,"
+ " \"allow_rules\": ["
Expand All @@ -392,13 +384,13 @@ public void staticAuthzAllowsRpcWithPrincipalsFieldOnTlsAuthenticatedConnectionT
+ "}";
AuthorizationServerInterceptor interceptor = createStaticAuthorizationInterceptor(policy);
ServerCredentials serverCredentials = TlsServerCredentials.newBuilder()
.keyManager(serverCert0File, serverKey0File)
.trustManager(caCertFile)
.keyManager(TlsTesting.loadCert(SERVER_0_PEM_FILE), TlsTesting.loadCert(SERVER_0_KEY_FILE))
.trustManager(TlsTesting.loadCert(CA_PEM_FILE))
.clientAuth(ClientAuth.OPTIONAL)
.build();
initServerWithAuthzInterceptor(interceptor, serverCredentials);
ChannelCredentials channelCredentials = TlsChannelCredentials.newBuilder()
.trustManager(caCertFile)
.trustManager(TlsTesting.loadCert(CA_PEM_FILE))
.build();
getStub(channelCredentials).unaryRpc(SimpleRequest.getDefaultInstance());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@
import io.grpc.benchmarks.proto.Control;
import io.grpc.benchmarks.proto.Stats;
import io.grpc.benchmarks.qps.AsyncServer;
import io.grpc.internal.testing.TestUtils;
import io.grpc.testing.TlsTesting;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.PooledByteBufAllocator;
import java.io.File;
import java.io.InputStream;
import java.lang.management.ManagementFactory;
import java.util.List;
import java.util.concurrent.ExecutorService;
Expand Down Expand Up @@ -115,8 +115,8 @@ final class LoadServer {
}
}
if (config.hasSecurityParams()) {
File cert = TestUtils.loadCert("server1.pem");
File key = TestUtils.loadCert("server1.key");
InputStream cert = TlsTesting.loadCert("server1.pem");
InputStream key = TlsTesting.loadCert("server1.key");
serverBuilder.useTransportSecurity(cert, key);
}
benchmarkService = new AsyncServer.BenchmarkServiceImpl();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,18 @@
import io.grpc.benchmarks.Utils;
import io.grpc.benchmarks.proto.BenchmarkServiceGrpc;
import io.grpc.benchmarks.proto.Messages;
import io.grpc.internal.testing.TestUtils;
import io.grpc.netty.NettyServerBuilder;
import io.grpc.stub.ServerCallStreamObserver;
import io.grpc.stub.StreamObserver;
import io.grpc.stub.StreamObservers;
import io.grpc.testing.TlsTesting;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.ServerChannel;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.util.concurrent.DefaultThreadFactory;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.util.Iterator;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinPool.ForkJoinWorkerThreadFactory;
Expand Down Expand Up @@ -164,8 +164,8 @@ static Server newServer(ServerConfiguration config) throws IOException {
System.out.println("Using fake CA for TLS certificate.\n"
+ "Run the Java client with --tls --testca");

File cert = TestUtils.loadCert("server1.pem");
File key = TestUtils.loadCert("server1.key");
InputStream cert = TlsTesting.loadCert("server1.pem");
InputStream key = TlsTesting.loadCert("server1.key");
builder.useTransportSecurity(cert, key);
}
if (config.directExecutor) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@
import io.grpc.ServerBuilder;
import io.grpc.Status;
import io.grpc.StatusException;
import io.grpc.internal.testing.TestUtils;
import io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.NegotiationType;
import io.grpc.netty.NettyChannelBuilder;
import io.grpc.stub.StreamObserver;
import io.grpc.testing.TlsTesting;
import io.netty.handler.ssl.SslContext;
import java.io.IOException;
import java.net.InetAddress;
Expand Down Expand Up @@ -345,7 +345,7 @@ private ManagedChannel createChannel(InetSocketAddress address) {
if (useTestCa) {
try {
sslContext = GrpcSslContexts.forClient().trustManager(
TestUtils.loadCert("ca.pem")).build();
TlsTesting.loadCert("ca.pem")).build();
} catch (Exception ex) {
throw new RuntimeException(ex);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import io.grpc.netty.NettyChannelBuilder;
import io.grpc.okhttp.InternalOkHttpChannelBuilder;
import io.grpc.okhttp.OkHttpChannelBuilder;
import io.grpc.testing.TlsTesting;
import java.io.File;
import java.io.FileInputStream;
import java.nio.charset.Charset;
Expand Down Expand Up @@ -537,7 +538,7 @@ protected ManagedChannelBuilder<?> createChannelBuilder() {
} else {
try {
channelCredentials = TlsChannelCredentials.newBuilder()
.trustManager(TestUtils.loadCert("ca.pem"))
.trustManager(TlsTesting.loadCert("ca.pem"))
.build();
} catch (Exception ex) {
throw new RuntimeException(ex);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import io.grpc.alts.AltsServerCredentials;
import io.grpc.internal.testing.TestUtils;
import io.grpc.services.MetricRecorder;
import io.grpc.testing.TlsTesting;
import io.grpc.xds.orca.OrcaMetricReportingServerInterceptor;
import io.grpc.xds.orca.OrcaServiceImpl;
import java.util.concurrent.Executors;
Expand Down Expand Up @@ -151,7 +152,7 @@ void start() throws Exception {
}
} else if (useTls) {
serverCreds = TlsServerCredentials.create(
TestUtils.loadCert("server1.pem"), TestUtils.loadCert("server1.key"));
TlsTesting.loadCert("server1.pem"), TlsTesting.loadCert("server1.key"));
} else {
serverCreds = InsecureServerCredentials.create();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@
import io.grpc.TlsServerCredentials;
import io.grpc.internal.testing.TestUtils;
import io.grpc.stub.StreamObserver;
import io.grpc.testing.TlsTesting;
import io.grpc.testing.integration.Messages.ResponseParameters;
import io.grpc.testing.integration.Messages.StreamingOutputCallRequest;
import io.grpc.testing.integration.Messages.StreamingOutputCallResponse;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
Expand Down Expand Up @@ -188,13 +188,9 @@ public void serverStreamingTest() throws Exception {
* Creates and starts a new {@link TestServiceImpl} server.
*/
private Server newServer() throws IOException {
File serverCertChainFile = TestUtils.loadCert("server1.pem");
File serverPrivateKeyFile = TestUtils.loadCert("server1.key");
File serverTrustedCaCerts = TestUtils.loadCert("ca.pem");

ServerCredentials serverCreds = TlsServerCredentials.newBuilder()
.keyManager(serverCertChainFile, serverPrivateKeyFile)
.trustManager(serverTrustedCaCerts)
.keyManager(TlsTesting.loadCert("server1.pem"), TlsTesting.loadCert("server1.key"))
.trustManager(TlsTesting.loadCert("ca.pem"))
.clientAuth(TlsServerCredentials.ClientAuth.REQUIRE)
.build();

Expand All @@ -205,13 +201,9 @@ private Server newServer() throws IOException {
}

private ManagedChannel newClientChannel() throws IOException {
File clientCertChainFile = TestUtils.loadCert("client.pem");
File clientPrivateKeyFile = TestUtils.loadCert("client.key");
File clientTrustedCaCerts = TestUtils.loadCert("ca.pem");

ChannelCredentials channelCreds = TlsChannelCredentials.newBuilder()
.keyManager(clientCertChainFile, clientPrivateKeyFile)
.trustManager(clientTrustedCaCerts)
.keyManager(TlsTesting.loadCert("client.pem"), TlsTesting.loadCert("client.key"))
.trustManager(TlsTesting.loadCert("ca.pem"))
.build();

return Grpc.newChannelBuilder("localhost:" + server.getPort(), channelCreds)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import io.grpc.netty.InternalNettyServerBuilder;
import io.grpc.netty.NettyChannelBuilder;
import io.grpc.netty.NettyServerBuilder;
import io.grpc.testing.TlsTesting;
import java.io.IOException;
import java.net.InetSocketAddress;
import org.junit.Test;
Expand All @@ -43,8 +44,8 @@ protected ServerBuilder<?> getServerBuilder() {
// Starts the server with HTTPS.
try {
ServerCredentials serverCreds = TlsServerCredentials.newBuilder()
.keyManager(TestUtils.loadCert("server1.pem"), TestUtils.loadCert("server1.key"))
.trustManager(TestUtils.loadCert("ca.pem"))
.keyManager(TlsTesting.loadCert("server1.pem"), TlsTesting.loadCert("server1.key"))
.trustManager(TlsTesting.loadCert("ca.pem"))
.clientAuth(TlsServerCredentials.ClientAuth.REQUIRE)
.build();
NettyServerBuilder builder = NettyServerBuilder.forPort(0, serverCreds)
Expand All @@ -62,8 +63,8 @@ protected ServerBuilder<?> getServerBuilder() {
protected NettyChannelBuilder createChannelBuilder() {
try {
ChannelCredentials channelCreds = TlsChannelCredentials.newBuilder()
.keyManager(TestUtils.loadCert("client.pem"), TestUtils.loadCert("client.key"))
.trustManager(TestUtils.loadCert("ca.pem"))
.keyManager(TlsTesting.loadCert("client.pem"), TlsTesting.loadCert("client.key"))
.trustManager(TlsTesting.loadCert("ca.pem"))
.build();
NettyChannelBuilder builder = NettyChannelBuilder
.forAddress("localhost", ((InetSocketAddress) getListenAddress()).getPort(), channelCreds)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import io.grpc.okhttp.OkHttpChannelBuilder;
import io.grpc.okhttp.internal.Platform;
import io.grpc.stub.StreamObserver;
import io.grpc.testing.TlsTesting;
import io.grpc.testing.integration.EmptyProtos.Empty;
import java.io.IOException;
import java.net.InetSocketAddress;
Expand Down Expand Up @@ -68,7 +69,7 @@ protected ServerBuilder<?> getServerBuilder() {
// Starts the server with HTTPS.
try {
ServerCredentials serverCreds = TlsServerCredentials.create(
TestUtils.loadCert("server1.pem"), TestUtils.loadCert("server1.key"));
TlsTesting.loadCert("server1.pem"), TlsTesting.loadCert("server1.key"));
NettyServerBuilder builder = NettyServerBuilder.forPort(0, serverCreds)
.flowControlWindow(AbstractInteropTest.TEST_FLOW_CONTROL_WINDOW)
.maxInboundMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE);
Expand All @@ -86,7 +87,7 @@ protected OkHttpChannelBuilder createChannelBuilder() {
ChannelCredentials channelCreds;
try {
channelCreds = TlsChannelCredentials.newBuilder()
.trustManager(TestUtils.loadCert("ca.pem"))
.trustManager(TlsTesting.loadCert("ca.pem"))
.build();
} catch (IOException ex) {
throw new RuntimeException(ex);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import io.grpc.okhttp.OkHttpChannelBuilder;
import io.grpc.okhttp.OkHttpServerBuilder;
import io.grpc.stub.MetadataUtils;
import io.grpc.testing.TlsTesting;
import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
Expand Down Expand Up @@ -87,7 +88,7 @@ protected ServerBuilder<?> getServerBuilder() {
ServerCredentials serverCreds;
try {
serverCreds = TlsServerCredentials.create(
TestUtils.loadCert("server1.pem"), TestUtils.loadCert("server1.key"));
TlsTesting.loadCert("server1.pem"), TlsTesting.loadCert("server1.key"));
} catch (IOException ex) {
throw new RuntimeException(ex);
}
Expand Down Expand Up @@ -115,7 +116,7 @@ protected ManagedChannelBuilder<?> createChannelBuilder() {
ChannelCredentials channelCreds;
try {
channelCreds = TlsChannelCredentials.newBuilder()
.trustManager(TestUtils.loadCert("ca.pem"))
.trustManager(TlsTesting.loadCert("ca.pem"))
.build();
} catch (Exception ex) {
throw new RuntimeException(ex);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@
import io.grpc.Server;
import io.grpc.ServerCredentials;
import io.grpc.TlsServerCredentials;
import io.grpc.internal.testing.TestUtils;
import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder;
import io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder;
import io.grpc.netty.shaded.io.grpc.netty.NettySslContextChannelCredentials;
import io.grpc.netty.shaded.io.netty.handler.ssl.SslContextBuilder;
import io.grpc.netty.shaded.io.netty.handler.ssl.SslProvider;
import io.grpc.stub.StreamObserver;
import io.grpc.testing.TlsTesting;
import io.grpc.testing.protobuf.SimpleRequest;
import io.grpc.testing.protobuf.SimpleResponse;
import io.grpc.testing.protobuf.SimpleServiceGrpc;
Expand Down Expand Up @@ -112,13 +112,13 @@ public void basic() throws Exception {
@Test
public void tcnative() throws Exception {
ServerCredentials serverCreds = TlsServerCredentials.create(
TestUtils.loadCert("server1.pem"), TestUtils.loadCert("server1.key"));
TlsTesting.loadCert("server1.pem"), TlsTesting.loadCert("server1.key"));
server = Grpc.newServerBuilderForPort(0, serverCreds)
.addService(new SimpleServiceImpl())
.build().start();
ChannelCredentials creds = NettySslContextChannelCredentials.create(
GrpcSslContexts.configure(SslContextBuilder.forClient(), SslProvider.OPENSSL)
.trustManager(TestUtils.loadCert("ca.pem")).build());
.trustManager(TlsTesting.loadCert("ca.pem")).build());
channel = Grpc.newChannelBuilder("localhost:" + server.getPort(), creds)
.overrideAuthority("foo.test.google.fr")
.build();
Expand Down
22 changes: 8 additions & 14 deletions netty/src/test/java/io/grpc/netty/AdvancedTlsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import io.grpc.TlsServerCredentials.ClientAuth;
import io.grpc.internal.testing.TestUtils;
import io.grpc.stub.StreamObserver;
import io.grpc.testing.TlsTesting;
import io.grpc.testing.protobuf.SimpleRequest;
import io.grpc.testing.protobuf.SimpleResponse;
import io.grpc.testing.protobuf.SimpleServiceGrpc;
Expand Down Expand Up @@ -98,20 +99,13 @@ public void setUp()
serverCert0File = TestUtils.loadCert(SERVER_0_PEM_FILE);
clientKey0File = TestUtils.loadCert(CLIENT_0_KEY_FILE);
clientCert0File = TestUtils.loadCert(CLIENT_0_PEM_FILE);
caCert = CertificateUtils.getX509Certificates(
TestUtils.class.getResourceAsStream("/certs/" + CA_PEM_FILE));
serverKey0 = CertificateUtils.getPrivateKey(
TestUtils.class.getResourceAsStream("/certs/" + SERVER_0_KEY_FILE));
serverCert0 = CertificateUtils.getX509Certificates(
TestUtils.class.getResourceAsStream("/certs/" + SERVER_0_PEM_FILE));
clientKey0 = CertificateUtils.getPrivateKey(
TestUtils.class.getResourceAsStream("/certs/" + CLIENT_0_KEY_FILE));
clientCert0 = CertificateUtils.getX509Certificates(
TestUtils.class.getResourceAsStream("/certs/" + CLIENT_0_PEM_FILE));
serverKeyBad = CertificateUtils.getPrivateKey(
TestUtils.class.getResourceAsStream("/certs/" + SERVER_BAD_KEY_FILE));
serverCertBad = CertificateUtils.getX509Certificates(
TestUtils.class.getResourceAsStream("/certs/" + SERVER_BAD_PEM_FILE));
caCert = CertificateUtils.getX509Certificates(TlsTesting.loadCert(CA_PEM_FILE));
serverKey0 = CertificateUtils.getPrivateKey(TlsTesting.loadCert(SERVER_0_KEY_FILE));
serverCert0 = CertificateUtils.getX509Certificates(TlsTesting.loadCert(SERVER_0_PEM_FILE));
clientKey0 = CertificateUtils.getPrivateKey(TlsTesting.loadCert(CLIENT_0_KEY_FILE));
clientCert0 = CertificateUtils.getX509Certificates(TlsTesting.loadCert(CLIENT_0_PEM_FILE));
serverKeyBad = CertificateUtils.getPrivateKey(TlsTesting.loadCert(SERVER_BAD_KEY_FILE));
serverCertBad = CertificateUtils.getX509Certificates(TlsTesting.loadCert(SERVER_BAD_PEM_FILE));
}

@After
Expand Down
Loading

0 comments on commit 74b515e

Please sign in to comment.