diff --git a/providers/flagd/README.md b/providers/flagd/README.md index 93e0df19c..9e4ada967 100644 --- a/providers/flagd/README.md +++ b/providers/flagd/README.md @@ -110,6 +110,7 @@ Given below are the supported configurations: | port | FLAGD_PORT | int | 8013 | rpc & in-process | | targetUri | FLAGD_TARGET_URI | string | null | rpc & in-process | | tls | FLAGD_TLS | boolean | false | rpc & in-process | +| defaultAuthority | FLAGD_DEFAULT_AUTHORITY | String | null | rpc & in-process | | socketPath | FLAGD_SOCKET_PATH | String | null | rpc & in-process | | certPath | FLAGD_SERVER_CERT_PATH | String | null | rpc & in-process | | deadline | FLAGD_DEADLINE_MS | int | 500 | rpc & in-process & file | @@ -180,6 +181,50 @@ FlagdProvider flagdProvider = new FlagdProvider( > There's a [vulnerability](https://security.snyk.io/vuln/SNYK-JAVA-IONETTY-1042268) in [netty](https://github.com/netty/netty), a transitive dependency of the underlying gRPC libraries used in the flagd-provider that fails to correctly validate certificates. > This will be addressed in netty v5. +### Configuring gRPC credentials and headers + +The `clientInterceptors` and `defaultAuthority` are meant for connection of the in-process resolver to a Sync API implementation on a host/port, that might require special credentials or headers. + +```java +private static ClientInterceptor createHeaderInterceptor() { + return new ClientInterceptor() { + @Override + public ClientCall interceptCall(MethodDescriptor method, CallOptions callOptions, Channel next) { + return new ForwardingClientCall.SimpleForwardingClientCall(next.newCall(method, callOptions)) { + @Override + public void start(Listener responseListener, Metadata headers) { + headers.put(Metadata.Key.of("custom-header", Metadata.ASCII_STRING_MARSHALLER), "header-value"); + super.start(responseListener, headers); + } + }; + } + }; +} + +private static ClientInterceptor createCallCrednetialsInterceptor(CallCredentials callCredentials) throws IOException { + return new ClientInterceptor() { + @Override + public ClientCall interceptCall(MethodDescriptor method, CallOptions callOptions, Channel next) { + return next.newCall(method, callOptions.withCallCredentials(callCredentials)); + } + }; +} + +List clientInterceptors = new ArrayList(2); +clientInterceptors.add(createHeaderInterceptor()); +CallCredentials myCallCredentals = ...; +clientInterceptors.add(createCallCrednetialsInterceptor(myCallCredentials)); + +FlagdProvider flagdProvider = new FlagdProvider( + FlagdOptions.builder() + .host("example.com/flagdSyncApi") + .port(443) + .tls(true) + .defaultAuthority("authority-host.sync.example.com") + .clientInterceptors(clientInterceptors) + .build()); +``` + ### Caching (RPC only) > [!NOTE] diff --git a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/Config.java b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/Config.java index 8c6da726c..56217f972 100644 --- a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/Config.java +++ b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/Config.java @@ -24,6 +24,7 @@ public final class Config { static final String HOST_ENV_VAR_NAME = "FLAGD_HOST"; static final String PORT_ENV_VAR_NAME = "FLAGD_PORT"; static final String TLS_ENV_VAR_NAME = "FLAGD_TLS"; + static final String DEFAULT_AUTHORITY_ENV_VAR_NAME = "FLAGD_DEFAULT_AUTHORITY"; static final String SOCKET_PATH_ENV_VAR_NAME = "FLAGD_SOCKET_PATH"; static final String SERVER_CERT_PATH_ENV_VAR_NAME = "FLAGD_SERVER_CERT_PATH"; static final String CACHE_ENV_VAR_NAME = "FLAGD_CACHE"; diff --git a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/FlagdOptions.java b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/FlagdOptions.java index 3d75d1063..765d07ccc 100644 --- a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/FlagdOptions.java +++ b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/FlagdOptions.java @@ -7,8 +7,10 @@ import dev.openfeature.sdk.EvaluationContext; import dev.openfeature.sdk.ImmutableContext; import dev.openfeature.sdk.Structure; +import io.grpc.ClientInterceptor; import io.opentelemetry.api.GlobalOpenTelemetry; import io.opentelemetry.api.OpenTelemetry; +import java.util.List; import java.util.function.Function; import lombok.Builder; import lombok.Getter; @@ -164,6 +166,18 @@ public class FlagdOptions { */ private OpenTelemetry openTelemetry; + /** + * gRPC client interceptors to be used when creating a gRPC channel. + */ + @Builder.Default + private List clientInterceptors = null; + + /** + * Authority header to be used when creating a gRPC channel. + */ + @Builder.Default + private String defaultAuthority = fallBackToEnvOrDefault(Config.DEFAULT_AUTHORITY_ENV_VAR_NAME, null); + /** * Builder overwrite in order to customize the "build" method. * diff --git a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/common/ChannelBuilder.java b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/common/ChannelBuilder.java index 51cb3e43f..26bb744a8 100644 --- a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/common/ChannelBuilder.java +++ b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/common/ChannelBuilder.java @@ -63,6 +63,12 @@ public static ManagedChannel nettyChannel(final FlagdOptions options) { final NettyChannelBuilder builder = NettyChannelBuilder.forTarget(targetUri).keepAliveTime(keepAliveMs, TimeUnit.MILLISECONDS); + if (options.getDefaultAuthority() != null) { + builder.overrideAuthority(options.getDefaultAuthority()); + } + if (options.getClientInterceptors() != null) { + builder.intercept(options.getClientInterceptors()); + } if (options.isTls()) { SslContextBuilder sslContext = GrpcSslContexts.forClient(); diff --git a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/FlagdOptionsTest.java b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/FlagdOptionsTest.java index 30c08dcd8..725ad8adc 100644 --- a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/FlagdOptionsTest.java +++ b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/FlagdOptionsTest.java @@ -20,7 +20,10 @@ import dev.openfeature.contrib.providers.flagd.resolver.process.storage.MockConnector; import dev.openfeature.contrib.providers.flagd.resolver.process.storage.connector.Connector; +import io.grpc.ClientInterceptor; import io.opentelemetry.api.OpenTelemetry; +import java.util.ArrayList; +import java.util.List; import java.util.function.Function; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; @@ -46,12 +49,15 @@ void TestDefaults() { assertNull(builder.getOfflineFlagSourcePath()); assertEquals(Resolver.RPC, builder.getResolverType()); assertEquals(0, builder.getKeepAlive()); + assertNull(builder.getDefaultAuthority()); + assertNull(builder.getClientInterceptors()); } @Test void TestBuilderOptions() { OpenTelemetry openTelemetry = Mockito.mock(OpenTelemetry.class); Connector connector = new MockConnector(null); + List clientInterceptors = new ArrayList(); FlagdOptions flagdOptions = FlagdOptions.builder() .host("https://hosted-flagd") @@ -66,6 +72,8 @@ void TestBuilderOptions() { .resolverType(Resolver.IN_PROCESS) .targetUri("dns:///localhost:8016") .keepAlive(1000) + .defaultAuthority("test-authority.sync.example.com") + .clientInterceptors(clientInterceptors) .build(); assertEquals("https://hosted-flagd", flagdOptions.getHost()); @@ -80,6 +88,8 @@ void TestBuilderOptions() { assertEquals(Resolver.IN_PROCESS, flagdOptions.getResolverType()); assertEquals("dns:///localhost:8016", flagdOptions.getTargetUri()); assertEquals(1000, flagdOptions.getKeepAlive()); + assertEquals("test-authority.sync.example.com", flagdOptions.getDefaultAuthority()); + assertEquals(clientInterceptors, flagdOptions.getClientInterceptors()); } @Test diff --git a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/common/ChannelBuilderTest.java b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/common/ChannelBuilderTest.java index 554310f76..865fbf984 100644 --- a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/common/ChannelBuilderTest.java +++ b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/common/ChannelBuilderTest.java @@ -2,6 +2,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.anyList; import static org.mockito.Mockito.any; import static org.mockito.Mockito.anyLong; import static org.mockito.Mockito.anyString; @@ -11,6 +12,7 @@ import static org.mockito.Mockito.when; import dev.openfeature.contrib.providers.flagd.FlagdOptions; +import io.grpc.ClientInterceptor; import io.grpc.ManagedChannel; import io.grpc.netty.GrpcSslContexts; import io.grpc.netty.NettyChannelBuilder; @@ -20,6 +22,8 @@ import io.netty.channel.unix.DomainSocketAddress; import io.netty.handler.ssl.SslContextBuilder; import java.io.File; +import java.util.ArrayList; +import java.util.List; import java.util.concurrent.TimeUnit; import javax.net.ssl.SSLKeyException; import org.junit.jupiter.api.Test; @@ -113,6 +117,83 @@ void testNettyChannel_withTlsAndCert() { } } + @Test + void testNettyChannel_withDefaultAuthority() { + try (MockedStatic nettyMock = mockStatic(NettyChannelBuilder.class)) { + // Mocks + NettyChannelBuilder mockBuilder = mock(NettyChannelBuilder.class); + ManagedChannel mockChannel = mock(ManagedChannel.class); + nettyMock + .when(() -> NettyChannelBuilder.forTarget("localhost:8080")) + .thenReturn(mockBuilder); + + when(mockBuilder.keepAliveTime(anyLong(), any(TimeUnit.class))).thenReturn(mockBuilder); + when(mockBuilder.sslContext(any())).thenReturn(mockBuilder); + when(mockBuilder.overrideAuthority(anyString())).thenReturn(mockBuilder); + when(mockBuilder.build()).thenReturn(mockChannel); + + // Input options + FlagdOptions options = FlagdOptions.builder() + .host("localhost") + .port(8080) + .keepAlive(5000) + .tls(true) + .defaultAuthority("test-authority.sync.example.com") + .build(); + + // Call method under test + ManagedChannel channel = ChannelBuilder.nettyChannel(options); + + // Assertions + assertThat(channel).isEqualTo(mockChannel); + nettyMock.verify(() -> NettyChannelBuilder.forTarget("localhost:8080")); + verify(mockBuilder).keepAliveTime(5000, TimeUnit.MILLISECONDS); + verify(mockBuilder).sslContext(any()); + verify(mockBuilder).overrideAuthority("test-authority.sync.example.com"); + verify(mockBuilder).build(); + } + } + + @Test + void testNettyChannel_withClientInterceptors() { + try (MockedStatic nettyMock = mockStatic(NettyChannelBuilder.class)) { + // Mocks + NettyChannelBuilder mockBuilder = mock(NettyChannelBuilder.class); + ManagedChannel mockChannel = mock(ManagedChannel.class); + nettyMock + .when(() -> NettyChannelBuilder.forTarget("localhost:8080")) + .thenReturn(mockBuilder); + + when(mockBuilder.keepAliveTime(anyLong(), any(TimeUnit.class))).thenReturn(mockBuilder); + when(mockBuilder.sslContext(any())).thenReturn(mockBuilder); + when(mockBuilder.intercept(anyList())).thenReturn(mockBuilder); + when(mockBuilder.build()).thenReturn(mockChannel); + + List clientInterceptors = new ArrayList(); + clientInterceptors.add(mock(ClientInterceptor.class)); + + // Input options + FlagdOptions options = FlagdOptions.builder() + .host("localhost") + .port(8080) + .keepAlive(5000) + .tls(true) + .clientInterceptors(clientInterceptors) + .build(); + + // Call method under test + ManagedChannel channel = ChannelBuilder.nettyChannel(options); + + // Assertions + assertThat(channel).isEqualTo(mockChannel); + nettyMock.verify(() -> NettyChannelBuilder.forTarget("localhost:8080")); + verify(mockBuilder).keepAliveTime(5000, TimeUnit.MILLISECONDS); + verify(mockBuilder).sslContext(any()); + verify(mockBuilder).intercept(clientInterceptors); + verify(mockBuilder).build(); + } + } + @ParameterizedTest @ValueSource(strings = {"/incorrect/{uri}/;)"}) void testNettyChannel_withInvalidTargetUri(String uri) {