Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ plugins {

allprojects {
group = 'com.strategyobject.substrateclient'
version = '0.0.1-SNAPSHOT'
version = '0.0.2-SNAPSHOT'

repositories {
mavenLocal()
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import com.google.common.base.Strings;
import com.strategyobject.substrateclient.common.eventemitter.EventEmitter;
import com.strategyobject.substrateclient.common.eventemitter.EventListener;
import com.strategyobject.substrateclient.common.gc.WeakReferenceFinalizer;
import com.strategyobject.substrateclient.transport.ProviderInterface;
import com.strategyobject.substrateclient.transport.ProviderInterfaceEmitted;
import com.strategyobject.substrateclient.transport.SubscriptionHandler;
Expand All @@ -18,9 +17,6 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.lang.ref.Reference;
import java.lang.ref.ReferenceQueue;
import java.lang.ref.WeakReference;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.*;
Expand Down Expand Up @@ -48,7 +44,7 @@ public WsStateSubscription(BiConsumer<Exception, Object> callBack,
@Getter
@Setter
class WsStateAwaiting<T> {
private WeakReference<CompletableFuture<T>> callBack;
private CompletableFuture<T> callback;
private String method;
private List<Object> params;
private SubscriptionHandler subscription;
Expand All @@ -58,14 +54,15 @@ public class WsProvider implements ProviderInterface, AutoCloseable {
private static final Logger logger = LoggerFactory.getLogger(WsProvider.class);
private static final int RESUBSCRIBE_TIMEOUT = 20;
private static final Map<String, String> ALIASES = new HashMap<>();
private static final ScheduledExecutorService timedOutHandlerCleaner = Executors
.newScheduledThreadPool(1);

static {
ALIASES.put("chain_finalisedHead", "chain_finalizedHead");
ALIASES.put("chain_subscribeFinalisedHeads", "chain_subscribeFinalizedHeads");
ALIASES.put("chain_unsubscribeFinalisedHeads", "chain_unsubscribeFinalizedHeads");
}

private final ReferenceQueue<CompletableFuture<?>> referenceQueue = new ReferenceQueue<>();
private final RpcCoder coder = new RpcCoder();
private final URI endpoint;
private final Map<String, String> headers;
Expand All @@ -75,13 +72,15 @@ public class WsProvider implements ProviderInterface, AutoCloseable {
private final Map<String, JsonRpcResponseSubscription> waitingForId = new ConcurrentHashMap<>();
private final int heartbeatInterval;
private final AtomicReference<WebSocketClient> webSocket = new AtomicReference<>(null);
private final long responseTimeoutInMs;
private int autoConnectMs;
private volatile boolean isConnected = false;

WsProvider(@NonNull URI endpoint,
int autoConnectMs,
Map<String, String> headers,
int heartbeatInterval) {
int heartbeatInterval,
long responseTimeoutInMs) {
Preconditions.checkArgument(
endpoint.getScheme().matches("(?i)ws|wss"),
"Endpoint should start with 'ws://', received " + endpoint);
Expand All @@ -93,6 +92,7 @@ public class WsProvider implements ProviderInterface, AutoCloseable {
this.autoConnectMs = autoConnectMs;
this.headers = headers;
this.heartbeatInterval = heartbeatInterval;
this.responseTimeoutInMs = responseTimeoutInMs;

if (autoConnectMs > 0) {
this.connect();
Expand Down Expand Up @@ -211,17 +211,14 @@ private <T> CompletableFuture<T> send(String method,
logger.debug("Calling {} {}, {}, {}, {}", id, method, params, json, subscription);

val whenResponseReceived = new CompletableFuture<T>();
val callback = new WeakReferenceFinalizer<>(
whenResponseReceived,
referenceQueue,
() -> this.handlers.remove(id));

this.handlers.put(id, new WsStateAwaiting<>(callback, method, params, subscription));
this.handlers.put(id, new WsStateAwaiting<>(whenResponseReceived, method, params, subscription));

return CompletableFuture.runAsync(() -> this.webSocket.get().send(json))
.whenCompleteAsync((_res, ex) -> {
if (ex != null) {
this.handlers.remove(id);
} else {
scheduleCleanupIfNoResponseWithinTimeout(id);
}
})
.thenCombineAsync(whenResponseReceived, (_a, b) -> b);
Expand Down Expand Up @@ -299,6 +296,23 @@ public CompletableFuture<Boolean> unsubscribe(String type, String method, String
return whenUnsubscribed;
}

private void scheduleCleanupIfNoResponseWithinTimeout(int id) {
timedOutHandlerCleaner.schedule(() -> {
val handler = this.handlers.remove(id);
if (handler == null) {
return;
}

handler
.getCallback()
.completeExceptionally(new TimeoutException(
String.format("The node didn't respond within %s milliseconds.",
responseTimeoutInMs)));
},
responseTimeoutInMs,
TimeUnit.MILLISECONDS);
}

private void emit(ProviderInterfaceEmitted type, Object... args) {
this.eventEmitter.emit(type, args);
}
Expand All @@ -324,12 +338,7 @@ private void onSocketClose(int code, String reason) {

// reject all hanging requests
val wsClosedException = new WsClosedException(errorMessage);
this.handlers.values().forEach(x -> {
val callback = x.getCallBack().get();
if (callback != null) {
callback.completeExceptionally(wsClosedException);
}
});
this.handlers.values().forEach(x -> x.getCallback().completeExceptionally(wsClosedException));
this.handlers.clear();
this.waitingForId.clear();

Expand All @@ -346,7 +355,6 @@ private void onSocketError(Exception ex) {

private void onSocketMessage(String message) {
logger.debug("Received {}", message);
this.cleanCollectedHandlers();

JsonRpcResponse response = RpcCoder.decodeJson(message);
if (Strings.isNullOrEmpty(response.getMethod())) {
Expand All @@ -365,12 +373,11 @@ private <T> void onSocketMessageResult(JsonRpcResponseSingle response) {
return;
}

val callback = Optional.ofNullable(handler.getCallBack().get());
try {
val result = (T) response.getResult();
// first send the result - in case of subs, we may have an update
// immediately if we have some queued results already
callback.ifPresent(x -> x.complete(result));
handler.getCallback().complete(result);

val subscription = handler.getSubscription();
if (subscription != null) {
Expand All @@ -390,7 +397,7 @@ private <T> void onSocketMessageResult(JsonRpcResponseSingle response) {
}
}
} catch (Exception ex) {
callback.ifPresent(x -> x.completeExceptionally(ex));
handler.getCallback().completeExceptionally(ex);
}

this.handlers.remove(id);
Expand Down Expand Up @@ -467,19 +474,12 @@ private void resubscribe() {
}
}

private void cleanCollectedHandlers() {
Reference<?> referenceFromQueue;
while ((referenceFromQueue = referenceQueue.poll()) != null) {
((WeakReferenceFinalizer<?>) referenceFromQueue).finalizeResources();
referenceFromQueue.clear();
}
}

public static class Builder {
private URI endpoint;
private int autoConnectMs = 2500;
private Map<String, String> headers = null;
private int heartbeatInterval = 60;
private long responseTimeoutInMs = 20000;

Builder() {
try {
Expand Down Expand Up @@ -527,8 +527,17 @@ public Builder disableHeartbeats() {
return this;
}

public Builder setResponseTimeout(long timeout, TimeUnit timeUnit) {
this.responseTimeoutInMs = timeUnit.toMillis(timeout);
return this;
}

public WsProvider build() {
return new WsProvider(this.endpoint, this.autoConnectMs, this.headers, this.heartbeatInterval);
return new WsProvider(this.endpoint,
this.autoConnectMs,
this.headers,
this.heartbeatInterval,
this.responseTimeoutInMs);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.strategyobject.substrateclient.tests.containers.SubstrateVersion;
import com.strategyobject.substrateclient.tests.containers.TestSubstrateContainer;
import eu.rekawek.toxiproxy.model.ToxicDirection;
import lombok.SneakyThrows;
import lombok.val;
import org.junit.jupiter.api.Test;
Expand All @@ -10,10 +11,13 @@
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;

import java.util.Map;
import java.util.concurrent.CompletionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

import static org.awaitility.Awaitility.await;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.*;

@Testcontainers
public class WsProviderProxyTest {
Expand All @@ -27,12 +31,9 @@ public class WsProviderProxyTest {
static final ToxiproxyContainer toxiproxy = new ToxiproxyContainer("shopify/toxiproxy")
.withNetwork(network)
.withNetworkAliases("toxiproxy");

final ToxiproxyContainer.ContainerProxy proxy = toxiproxy.getProxy(substrate, 9944);

private static final int HEARTBEAT_INTERVAL = 5;
private static final int WAIT_TIMEOUT = HEARTBEAT_INTERVAL * 2;

final ToxiproxyContainer.ContainerProxy proxy = toxiproxy.getProxy(substrate, 9944);

@Test
void canReconnect() {
Expand Down Expand Up @@ -77,6 +78,68 @@ void canAutoConnectWhenServerAvailable() {
}
}

@Test
@SneakyThrows
void throwsExceptionWhenCanNotSendRequestAndCleanHandler() {
try (val wsProvider = WsProvider.builder()
.setEndpoint(getWsAddress())
.disableAutoConnect()
.build()) {

wsProvider.connect().get(WAIT_TIMEOUT, TimeUnit.SECONDS);
assertTrue(wsProvider.isConnected());

val timeout = proxy
.toxics()
.timeout("timeout", ToxicDirection.UPSTREAM, 1000);

val exception = assertThrows(CompletionException.class,
() -> wsProvider.send("system_version").join());
assertTrue(exception.getCause() instanceof WsClosedException);

val handlers = getHandlersOf(wsProvider);
assertEquals(0, handlers.size());

timeout.remove();
}
}

@Test
@SneakyThrows
void throwsExceptionWhenResponseTimeoutAndCleanHandler() {
val responseTimeout = 500;
try (val wsProvider = WsProvider.builder()
.setEndpoint(getWsAddress())
.setResponseTimeout(responseTimeout, TimeUnit.MILLISECONDS)
.disableAutoConnect()
.build()) {

wsProvider.connect().get(WAIT_TIMEOUT, TimeUnit.SECONDS);
assertTrue(wsProvider.isConnected());

val latency = proxy
.toxics()
.latency("latency", ToxicDirection.DOWNSTREAM, responseTimeout * 3);

val exception = assertThrows(CompletionException.class,
() -> wsProvider.send("system_version").join(),
String.format("The node didn't respond for %s milliseconds.", responseTimeout));
assertTrue(exception.getCause() instanceof TimeoutException);

val handlers = getHandlersOf(wsProvider);
assertEquals(0, handlers.size());

latency.remove();
}
}

private Map<?, ?> getHandlersOf(WsProvider wsProvider) throws NoSuchFieldException, IllegalAccessException {
val handlersFields = wsProvider.getClass().getDeclaredField("handlers");
handlersFields.setAccessible(true);

return (Map<?, ?>) handlersFields.get(wsProvider);
}

private String getWsAddress() {
return String.format("ws://%s:%s", proxy.getContainerIpAddress(), proxy.getProxyPort());
}
Expand Down