Skip to content

Commit 6ed6266

Browse files
authored
fix: terminate connection for invalid messages (#154)
* test: add tests for invalid messages * fix: terminate connection for invalid messages The backend should protect itself against a malicious client by dropping the connection if it receives an invalid message, or a message that claims to be longer than is possible. * fix: terminate connection for invalid messages The backend should protect itself against a malicious client by dropping the connection if it receives an invalid message, or a message that claims to be longer than is possible. * fix: increase max message length * fix: remove max message length check
1 parent 7298244 commit 6ed6266

File tree

5 files changed

+150
-2
lines changed

5 files changed

+150
-2
lines changed

src/main/java/com/google/cloud/spanner/pgadapter/ConnectionHandler.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,9 @@ public void run() {
187187
try {
188188
message.nextHandler();
189189
message.send();
190+
} catch (IllegalArgumentException | IllegalStateException | EOFException fatalException) {
191+
this.handleError(output, fatalException);
192+
this.status = ConnectionStatus.TERMINATED;
190193
} catch (Exception e) {
191194
this.handleError(output, e);
192195
}
@@ -226,7 +229,9 @@ public void run() {
226229
/** Called when a Terminate message is received. This closes this {@link ConnectionHandler}. */
227230
public void handleTerminate() {
228231
closeAllPortals();
229-
this.spannerConnection.close();
232+
if (this.spannerConnection != null) {
233+
this.spannerConnection.close();
234+
}
230235
this.status = ConnectionStatus.TERMINATED;
231236
}
232237

@@ -237,6 +242,8 @@ public void handleTerminate() {
237242
void terminate() throws IOException {
238243
if (this.status != ConnectionStatus.TERMINATED) {
239244
handleTerminate();
245+
}
246+
if (!socket.isClosed()) {
240247
socket.close();
241248
}
242249
}

src/main/java/com/google/cloud/spanner/pgadapter/ProxyServer.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,16 @@ void runServer() throws IOException {
174174
getLocalPort(), e));
175175
} finally {
176176
for (ConnectionHandler handler : this.handlers) {
177-
handler.terminate();
177+
try {
178+
handler.terminate();
179+
} catch (Exception exception) {
180+
logger.log(
181+
Level.WARNING,
182+
exception,
183+
() ->
184+
String.format(
185+
"Connection handler %s could not be terminated: %s", handler, exception));
186+
}
178187
}
179188
logger.log(Level.INFO, () -> String.format("Socket on port %d stopped", getLocalPort()));
180189
notifyStopped();

src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/BootstrapMessage.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
*/
3434
@InternalApi
3535
public abstract class BootstrapMessage extends WireMessage {
36+
private static final int MAX_BOOTSTRAP_MESSAGE_LENGTH = 1 << 8;
3637

3738
public BootstrapMessage(ConnectionHandler connection, int length) {
3839
super(connection, length);
@@ -48,6 +49,9 @@ public BootstrapMessage(ConnectionHandler connection, int length) {
4849
*/
4950
public static BootstrapMessage create(ConnectionHandler connection) throws Exception {
5051
int length = connection.getConnectionMetadata().getInputStream().readInt();
52+
if (length > MAX_BOOTSTRAP_MESSAGE_LENGTH) {
53+
throw new IllegalArgumentException("Invalid bootstrap message length: " + length);
54+
}
5155
int protocol = connection.getConnectionMetadata().getInputStream().readInt();
5256
switch (protocol) {
5357
case SSLMessage.IDENTIFIER:

src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/WireMessage.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import com.google.api.core.InternalApi;
1818
import com.google.cloud.spanner.pgadapter.ConnectionHandler;
19+
import com.google.common.base.Preconditions;
1920
import java.io.DataInputStream;
2021
import java.io.DataOutputStream;
2122
import java.io.IOException;
@@ -37,6 +38,7 @@ public abstract class WireMessage {
3738
protected ConnectionHandler connection;
3839

3940
public WireMessage(ConnectionHandler connection, int length) {
41+
Preconditions.checkArgument(length >= 4);
4042
this.connection = connection;
4143
this.inputStream = connection.getConnectionMetadata().getInputStream();
4244
this.outputStream = connection.getConnectionMetadata().getOutputStream();
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
// Copyright 2022 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package com.google.cloud.spanner.pgadapter;
16+
17+
import static org.junit.Assert.assertEquals;
18+
19+
import com.google.cloud.spanner.pgadapter.wireprotocol.SSLMessage;
20+
import com.google.cloud.spanner.pgadapter.wireprotocol.StartupMessage;
21+
import java.io.DataInputStream;
22+
import java.io.DataOutputStream;
23+
import java.io.IOException;
24+
import java.net.Socket;
25+
import java.nio.charset.StandardCharsets;
26+
import org.junit.Test;
27+
import org.junit.runner.RunWith;
28+
import org.junit.runners.JUnit4;
29+
30+
@RunWith(JUnit4.class)
31+
public class InvalidMessagesTest extends AbstractMockServerTest {
32+
33+
@Test
34+
public void testConnectionWithoutMessages() throws IOException {
35+
try (Socket ignored = new Socket("localhost", pgServer.getLocalPort())) {
36+
// Do nothing, just close the socket again.
37+
}
38+
}
39+
40+
@Test
41+
public void testGarbledStartupMessage() throws IOException {
42+
try (Socket socket = new Socket("localhost", pgServer.getLocalPort())) {
43+
socket.getOutputStream().write("foo".getBytes(StandardCharsets.UTF_8));
44+
}
45+
}
46+
47+
@Test
48+
public void testDropConnectionAfterStartup() throws IOException {
49+
try (Socket socket = new Socket("localhost", pgServer.getLocalPort())) {
50+
try (DataOutputStream outputStream = new DataOutputStream(socket.getOutputStream())) {
51+
// Send a startup message and then quit.
52+
outputStream.writeInt(8); // length == 8
53+
outputStream.writeInt(StartupMessage.IDENTIFIER);
54+
outputStream.flush();
55+
}
56+
}
57+
}
58+
59+
@Test
60+
public void testDropConnectionAfterRefusedSSL() throws IOException {
61+
try (Socket socket = new Socket("localhost", pgServer.getLocalPort())) {
62+
try (DataInputStream inputStream = new DataInputStream(socket.getInputStream());
63+
DataOutputStream outputStream = new DataOutputStream(socket.getOutputStream())) {
64+
// Request SSL.
65+
outputStream.writeInt(8); // length == 8
66+
outputStream.writeInt(SSLMessage.IDENTIFIER);
67+
outputStream.flush();
68+
69+
// Verify that it is refused by the server.
70+
byte response = inputStream.readByte();
71+
assertEquals('N', response);
72+
}
73+
}
74+
}
75+
76+
@Test
77+
public void testDropConnectionAfterStartupMessage() throws IOException {
78+
try (Socket socket = new Socket("localhost", pgServer.getLocalPort())) {
79+
try (DataInputStream inputStream = new DataInputStream(socket.getInputStream());
80+
DataOutputStream outputStream = new DataOutputStream(socket.getOutputStream())) {
81+
// Request startup.
82+
outputStream.writeInt(17);
83+
outputStream.writeInt(StartupMessage.IDENTIFIER);
84+
outputStream.writeBytes("user");
85+
outputStream.writeByte(0);
86+
outputStream.writeBytes("foo");
87+
outputStream.writeByte(0);
88+
outputStream.flush();
89+
90+
// Verify that the server responds with auth OK.
91+
assertEquals('R', inputStream.readByte());
92+
assertEquals(8, inputStream.readInt());
93+
assertEquals(0, inputStream.readInt()); // 0 == success
94+
}
95+
}
96+
}
97+
98+
@Test
99+
public void testSendGarbageAfterStartupMessage() throws IOException {
100+
try (Socket socket = new Socket("localhost", pgServer.getLocalPort())) {
101+
try (DataInputStream inputStream = new DataInputStream(socket.getInputStream());
102+
DataOutputStream outputStream = new DataOutputStream(socket.getOutputStream())) {
103+
// Request startup.
104+
outputStream.writeInt(17);
105+
outputStream.writeInt(StartupMessage.IDENTIFIER);
106+
outputStream.writeBytes("user");
107+
outputStream.writeByte(0);
108+
outputStream.writeBytes("foo");
109+
outputStream.writeByte(0);
110+
outputStream.flush();
111+
112+
// Then send a random message with no meaning and drop the connection.
113+
outputStream.writeInt(20);
114+
outputStream.writeChar(' ');
115+
outputStream.flush();
116+
117+
// Read until the end of the stream. The stream should be closed by the backend.
118+
int bytesRead = 0;
119+
while (inputStream.read() > -1 && bytesRead < 1 << 16) {
120+
bytesRead++;
121+
}
122+
assertEquals(-1, inputStream.read());
123+
}
124+
}
125+
}
126+
}

0 commit comments

Comments
 (0)