Skip to content

Commit 8adba3f

Browse files
sarahchen6mcculls
andauthored
Add buffer size customizability to JDK UDS support (#8629)
* Improve timeout test and add test for buffer sizes * Add ability to set buffer sizes * Update utils/socket-utils/src/main/java17/datadog/common/socket/TunnelingJdkSocket.java Co-authored-by: Stuart McCulloch <[email protected]> * Change DEFAULT_BUFFER_SIZE to be visible for classes in same package * Fix comment --------- Co-authored-by: Stuart McCulloch <[email protected]>
1 parent ba7123a commit 8adba3f

File tree

2 files changed

+148
-16
lines changed

2 files changed

+148
-16
lines changed

utils/socket-utils/src/main/java17/datadog/common/socket/TunnelingJdkSocket.java

Lines changed: 72 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121
* Subtype UNIX socket for a higher-fidelity impersonation of TCP sockets. This is named "tunneling"
2222
* because it assumes the ultimate destination has a hostname and port.
2323
*
24-
* <p>Bsed on {@link TunnelingUnixSocket}; adapted to use the built-in UDS support added in Java 16.
24+
* <p>Based on {@link TunnelingUnixSocket}; adapted to use the built-in UDS support added in Java
25+
* 16.
2526
*/
2627
final class TunnelingJdkSocket extends Socket {
2728
private final SocketAddress unixSocketAddress;
@@ -34,6 +35,11 @@ final class TunnelingJdkSocket extends Socket {
3435
private boolean shutOut;
3536
private boolean closed;
3637

38+
protected static final int DEFAULT_BUFFER_SIZE = 8192;
39+
// Indicate that the buffer size is not set by initializing to -1
40+
private int sendBufferSize = -1;
41+
private int receiveBufferSize = -1;
42+
3743
TunnelingJdkSocket(final Path path) {
3844
this.unixSocketAddress = UnixDomainSocketAddress.of(path);
3945
}
@@ -114,6 +120,70 @@ public SocketChannel getChannel() {
114120
return unixSocketChannel;
115121
}
116122

123+
@Override
124+
public void setSendBufferSize(int size) throws SocketException {
125+
if (isClosed()) {
126+
throw new SocketException("Socket is closed");
127+
}
128+
if (size < 0) {
129+
throw new IllegalArgumentException("Invalid send buffer size");
130+
}
131+
try {
132+
unixSocketChannel.setOption(java.net.StandardSocketOptions.SO_SNDBUF, size);
133+
sendBufferSize = size;
134+
} catch (IOException e) {
135+
throw new SocketException("Failed to set send buffer size");
136+
}
137+
}
138+
139+
@Override
140+
public int getSendBufferSize() throws SocketException {
141+
if (isClosed()) {
142+
throw new SocketException("Socket is closed");
143+
}
144+
if (sendBufferSize == -1) {
145+
return DEFAULT_BUFFER_SIZE;
146+
}
147+
return sendBufferSize;
148+
}
149+
150+
@Override
151+
public void setReceiveBufferSize(int size) throws SocketException {
152+
if (isClosed()) {
153+
throw new SocketException("Socket is closed");
154+
}
155+
if (size < 0) {
156+
throw new IllegalArgumentException("Invalid receive buffer size");
157+
}
158+
try {
159+
unixSocketChannel.setOption(java.net.StandardSocketOptions.SO_RCVBUF, size);
160+
receiveBufferSize = size;
161+
} catch (IOException e) {
162+
throw new SocketException("Failed to set receive buffer size");
163+
}
164+
}
165+
166+
@Override
167+
public int getReceiveBufferSize() throws SocketException {
168+
if (isClosed()) {
169+
throw new SocketException("Socket is closed");
170+
}
171+
if (receiveBufferSize == -1) {
172+
return DEFAULT_BUFFER_SIZE;
173+
}
174+
return receiveBufferSize;
175+
}
176+
177+
public int getStreamBufferSize() throws SocketException {
178+
if (isClosed()) {
179+
throw new SocketException("Socket is closed");
180+
}
181+
if (sendBufferSize == -1 && receiveBufferSize == -1) {
182+
return DEFAULT_BUFFER_SIZE;
183+
}
184+
return Math.max(sendBufferSize, receiveBufferSize);
185+
}
186+
117187
@Override
118188
public InputStream getInputStream() throws IOException {
119189
if (isClosed()) {
@@ -127,7 +197,7 @@ public InputStream getInputStream() throws IOException {
127197
}
128198

129199
return new InputStream() {
130-
private final ByteBuffer buffer = ByteBuffer.allocate(8192);
200+
private final ByteBuffer buffer = ByteBuffer.allocate(getStreamBufferSize());
131201
private final Selector selector = Selector.open();
132202

133203
{

utils/socket-utils/src/test/java/datadog/common/socket/TunnelingJdkSocketTest.java

Lines changed: 76 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
package datadog.common.socket;
22

3+
import static org.junit.jupiter.api.Assertions.*;
34
import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively;
4-
import static org.junit.jupiter.api.Assertions.fail;
55

66
import datadog.trace.api.Config;
77
import java.io.IOException;
8+
import java.io.InputStream;
89
import java.net.InetSocketAddress;
10+
import java.net.SocketException;
911
import java.net.StandardProtocolFamily;
1012
import java.net.UnixDomainSocketAddress;
1113
import java.nio.channels.ServerSocketChannel;
@@ -28,32 +30,92 @@ public void testTimeout() throws Exception {
2830
return;
2931
}
3032

31-
int testTimeout = 3000;
3233
Path socketPath = getSocketPath();
3334
UnixDomainSocketAddress socketAddress = UnixDomainSocketAddress.of(socketPath);
3435
startServer(socketAddress);
3536
TunnelingJdkSocket clientSocket = createClient(socketPath);
37+
InputStream inputStream = clientSocket.getInputStream();
3638

37-
// Test that the socket unblocks when timeout is set to >0
38-
clientSocket.setSoTimeout(1000);
39-
assertTimeoutPreemptively(
40-
Duration.ofMillis(testTimeout), () -> clientSocket.getInputStream().read());
39+
int testTimeout = 1000;
40+
clientSocket.setSoTimeout(testTimeout);
41+
assertEquals(testTimeout, clientSocket.getSoTimeout());
4142

42-
// Test that the socket blocks indefinitely when timeout is set to 0, per
43+
long startTime = System.currentTimeMillis();
44+
int readResult = inputStream.read();
45+
long endTime = System.currentTimeMillis();
46+
long readDuration = endTime - startTime;
47+
int timeVariance = 100;
48+
assertTrue(readDuration >= testTimeout && readDuration <= testTimeout + timeVariance);
49+
assertEquals(0, readResult);
50+
51+
int newTimeout = testTimeout / 2;
52+
clientSocket.setSoTimeout(newTimeout);
53+
assertEquals(newTimeout, clientSocket.getSoTimeout());
54+
assertTimeoutPreemptively(Duration.ofMillis(testTimeout), () -> inputStream.read());
55+
56+
// The socket should block indefinitely when timeout is set to 0, per
4357
// https://docs.oracle.com/en/java/javase/16/docs/api//java.base/java/net/Socket.html#setSoTimeout(int).
44-
clientSocket.setSoTimeout(0);
45-
boolean infiniteTimeOut = false;
58+
int infiniteTimeout = 0;
59+
clientSocket.setSoTimeout(infiniteTimeout);
60+
assertEquals(infiniteTimeout, clientSocket.getSoTimeout());
4661
try {
47-
assertTimeoutPreemptively(
48-
Duration.ofMillis(testTimeout), () -> clientSocket.getInputStream().read());
62+
assertTimeoutPreemptively(Duration.ofMillis(testTimeout), () -> inputStream.read());
63+
fail("Read should block indefinitely with infinite timeout");
4964
} catch (AssertionError e) {
50-
infiniteTimeOut = true;
65+
// Expected
5166
}
52-
if (!infiniteTimeOut) {
53-
fail("Test failed: Expected infinite blocking when timeout is set to 0.");
67+
68+
int invalidTimeout = -1;
69+
assertThrows(IllegalArgumentException.class, () -> clientSocket.setSoTimeout(invalidTimeout));
70+
71+
clientSocket.close();
72+
assertThrows(SocketException.class, () -> clientSocket.setSoTimeout(testTimeout));
73+
assertThrows(SocketException.class, () -> clientSocket.getSoTimeout());
74+
75+
isServerRunning.set(false);
76+
}
77+
78+
@Test
79+
public void testBufferSizes() throws Exception {
80+
if (!Config.get().isJdkSocketEnabled()) {
81+
System.out.println(
82+
"TunnelingJdkSocket usage is disabled. Enable it by setting the property 'JDK_SOCKET_ENABLED' to 'true'.");
83+
return;
5484
}
5585

86+
Path socketPath = getSocketPath();
87+
UnixDomainSocketAddress socketAddress = UnixDomainSocketAddress.of(socketPath);
88+
startServer(socketAddress);
89+
TunnelingJdkSocket clientSocket = createClient(socketPath);
90+
91+
assertEquals(TunnelingJdkSocket.DEFAULT_BUFFER_SIZE, clientSocket.getSendBufferSize());
92+
assertEquals(TunnelingJdkSocket.DEFAULT_BUFFER_SIZE, clientSocket.getReceiveBufferSize());
93+
assertEquals(TunnelingJdkSocket.DEFAULT_BUFFER_SIZE, clientSocket.getStreamBufferSize());
94+
95+
int newBufferSize = TunnelingJdkSocket.DEFAULT_BUFFER_SIZE / 2;
96+
clientSocket.setSendBufferSize(newBufferSize);
97+
clientSocket.setReceiveBufferSize(newBufferSize / 2);
98+
assertEquals(newBufferSize, clientSocket.getSendBufferSize());
99+
assertEquals(newBufferSize / 2, clientSocket.getReceiveBufferSize());
100+
assertEquals(newBufferSize, clientSocket.getStreamBufferSize());
101+
102+
int invalidBufferSize = -1;
103+
assertThrows(
104+
IllegalArgumentException.class, () -> clientSocket.setSendBufferSize(invalidBufferSize));
105+
assertThrows(
106+
IllegalArgumentException.class, () -> clientSocket.setReceiveBufferSize(invalidBufferSize));
107+
56108
clientSocket.close();
109+
assertThrows(
110+
SocketException.class,
111+
() -> clientSocket.setSendBufferSize(TunnelingJdkSocket.DEFAULT_BUFFER_SIZE));
112+
assertThrows(
113+
SocketException.class,
114+
() -> clientSocket.setReceiveBufferSize(TunnelingJdkSocket.DEFAULT_BUFFER_SIZE));
115+
assertThrows(SocketException.class, () -> clientSocket.getSendBufferSize());
116+
assertThrows(SocketException.class, () -> clientSocket.getReceiveBufferSize());
117+
assertThrows(SocketException.class, () -> clientSocket.getStreamBufferSize());
118+
57119
isServerRunning.set(false);
58120
}
59121

0 commit comments

Comments
 (0)