Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public JSONObject targetSchemaJSONObject(final String targetName) {
new JSONTokener(
this.targetSchemas.computeIfAbsent(
targetName,
tn -> this.getClass().getClassLoader().getResourceAsStream(tn)
tn -> this.getClass().getClassLoader().getResourceAsStream(this.targetSchemaPaths.get(tn))
)
)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
import software.amazon.awssdk.http.HttpStatusFamily;
import software.amazon.awssdk.http.SdkHttpClient;
import software.amazon.awssdk.http.apache.ApacheHttpClient;
import software.amazon.awssdk.utils.StringUtils;
import software.amazon.cloudformation.encryption.Cipher;
import software.amazon.cloudformation.encryption.KMSCipher;
import software.amazon.cloudformation.exceptions.BaseHandlerException;
Expand Down Expand Up @@ -67,7 +66,6 @@
import software.amazon.cloudformation.resource.SchemaValidator;
import software.amazon.cloudformation.resource.Serializer;
import software.amazon.cloudformation.resource.Validator;
import software.amazon.cloudformation.resource.exceptions.ValidationException;

public abstract class HookAbstractWrapper<TargetT, CallbackT, ConfigurationT> {

Expand All @@ -90,7 +88,6 @@ public abstract class HookAbstractWrapper<TargetT, CallbackT, ConfigurationT> {
final SchemaValidator validator;
final TypeReference<HookInvocationRequest<ConfigurationT, CallbackT>> typeReference;

private MetricsPublisher platformMetricsPublisher;
private MetricsPublisher providerMetricsPublisher;

private CloudWatchLogHelper cloudWatchLogHelper;
Expand All @@ -112,7 +109,6 @@ protected HookAbstractWrapper() {
public HookAbstractWrapper(final CredentialsProvider providerCredentialsProvider,
final CloudWatchLogPublisher providerEventsLogger,
final LogPublisher platformEventsLogger,
final MetricsPublisher platformMetricsPublisher,
final MetricsPublisher providerMetricsPublisher,
final SchemaValidator validator,
final Serializer serializer,
Expand All @@ -123,7 +119,6 @@ public HookAbstractWrapper(final CredentialsProvider providerCredentialsProvider
this.cloudWatchLogsProvider = new CloudWatchLogsProvider(this.providerCredentialsProvider, httpClient);
this.providerEventsLogger = providerEventsLogger;
this.platformLogPublisher = platformEventsLogger;
this.platformMetricsPublisher = platformMetricsPublisher;
this.providerMetricsPublisher = providerMetricsPublisher;
this.serializer = serializer;
this.validator = validator;
Expand All @@ -147,22 +142,14 @@ private void initialiseRuntime(final String hookTypeName,
this.loggerProxy = new LoggerProxy();
this.loggerProxy.addLogPublisher(this.platformLogPublisher);

// Initialisation skipped if dependencies were set during injection (in unit
// tests).

// Initialize a KMS cipher to decrypt customer credentials in HookRequestData
if (this.cipher == null && hookEncryptionKeyArn != null && hookEncryptionKeyRole != null) {
this.cipher = new KMSCipher(hookEncryptionKeyArn, hookEncryptionKeyRole);
}

// Initialisation skipped if dependencies were set during injection (in unit
// tests).
// e.g. "if (this.platformMetricsPublisher == null)"
if (this.platformMetricsPublisher == null) {
// platformMetricsPublisher needs aws account id to differentiate metrics
// namespace
this.platformMetricsPublisher = new HookMetricsPublisherImpl(this.platformLoggerProxy, awsAccountId, hookTypeName);
}
this.metricsPublisherProxy.addMetricsPublisher(this.platformMetricsPublisher);
this.platformMetricsPublisher.refreshClient();

// NOTE: providerCredentials and providerLogGroupName are null/not null in
// sync.
// Both are required parameters when LoggingConfig (optional) is provided when
Expand Down Expand Up @@ -210,19 +197,6 @@ public void processRequest(final InputStream inputStream, final OutputStream out
// deserialize incoming payload to modeled request
request = this.serializer.deserialize(input, typeReference);
handlerResponse = processInvocation(rawInput, request);
} catch (final ValidationException e) {
String message;
String fullExceptionMessage = ValidationException.buildFullExceptionMessage(e);
if (!StringUtils.isEmpty(fullExceptionMessage)) {
message = String.format("Model validation failed (%s)", fullExceptionMessage);
} else {
message = "Model validation failed with unknown cause.";
}

handlerResponse = ProgressEvent.defaultFailureHandler(new TerminalException(message, e),
HandlerErrorCode.InvalidRequest);
publishExceptionMetric(request != null ? request.getActionInvocationPoint() : null, e,
HandlerErrorCode.InvalidRequest);
} catch (final Throwable e) {
// Exceptions are wrapped as a consistent error response to the caller (i.e;
// CloudFormation)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,13 @@ public HookExecutableWrapper() {
public HookExecutableWrapper(final CredentialsProvider providerCredentialsProvider,
final CloudWatchLogPublisher providerEventsLogger,
final LogPublisher platformEventsLogger,
final MetricsPublisher platformMetricsPublisher,
final MetricsPublisher providerMetricsPublisher,
final SchemaValidator validator,
final Serializer serializer,
final SdkHttpClient httpClient,
final Cipher cipher) {
super(providerCredentialsProvider, providerEventsLogger, platformEventsLogger, platformMetricsPublisher,
providerMetricsPublisher, validator, serializer, httpClient, cipher);
super(providerCredentialsProvider, providerEventsLogger, platformEventsLogger, providerMetricsPublisher, validator,
serializer, httpClient, cipher);
}

public void handleRequest(final InputStream inputStream, final OutputStream outputStream) throws IOException,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,13 @@ public HookLambdaWrapper() {
public HookLambdaWrapper(final CredentialsProvider providerCredentialsProvider,
final CloudWatchLogPublisher providerEventsLogger,
final LogPublisher platformEventsLogger,
final MetricsPublisher platformMetricsPublisher,
final MetricsPublisher providerMetricsPublisher,
final SchemaValidator validator,
final Serializer serializer,
final SdkHttpClient httpClient,
final Cipher cipher) {
super(providerCredentialsProvider, providerEventsLogger, platformEventsLogger, platformMetricsPublisher,
providerMetricsPublisher, validator, serializer, httpClient, cipher);
super(providerCredentialsProvider, providerEventsLogger, platformEventsLogger, providerMetricsPublisher, validator,
serializer, httpClient, cipher);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
import java.util.Collection;
import java.util.EnumSet;
import java.util.HashSet;
import software.amazon.awssdk.core.SdkSystemSetting;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.cloudwatch.CloudWatchClient;
import software.amazon.awssdk.services.cloudwatch.model.Dimension;
import software.amazon.awssdk.services.cloudwatch.model.MetricDatum;
Expand All @@ -32,9 +30,7 @@
import software.amazon.cloudformation.proxy.Logger;

public class HookMetricsPublisherImpl extends MetricsPublisher {
private static final String DEFAULT_REGION = "us-east-1";

private CloudWatchProvider cloudWatchProvider;
private final CloudWatchProvider cloudWatchProvider;
private Logger loggerProxy;
private String awsAccountId;
private CloudWatchClient cloudWatchClient;
Expand All @@ -49,20 +45,9 @@ public HookMetricsPublisherImpl(final CloudWatchProvider cloudWatchProvider,
this.awsAccountId = awsAccountId;
}

public HookMetricsPublisherImpl(final Logger loggerProxy,
final String awsAccountId,
final String hookTypeName) {
super(hookTypeName);
this.loggerProxy = loggerProxy;
this.awsAccountId = awsAccountId;
this.cloudWatchClient = createClient();
}

@Override
public void refreshClient() {
if (cloudWatchProvider != null) {
this.cloudWatchClient = cloudWatchProvider.get();
}
this.cloudWatchClient = cloudWatchProvider.get();
}

private String getHookTypeName() {
Expand Down Expand Up @@ -166,9 +151,4 @@ private void log(final String message) {
}
}

private CloudWatchClient createClient() {
final String region = SdkSystemSetting.AWS_REGION.getStringValue().map(Object::toString).orElse(DEFAULT_REGION);
return CloudWatchClient.builder().region(Region.of(region)).build();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,12 @@ public class HookExecutableWrapperOverride extends HookExecutableWrapper<TestMod
public HookExecutableWrapperOverride(final CredentialsProvider providerLoggingCredentialsProvider,
final LogPublisher platformEventsLogger,
final CloudWatchLogPublisher providerEventsLogger,
final MetricsPublisher platformMetricsPublisher,
final MetricsPublisher providerMetricsPublisher,
final SchemaValidator validator,
final SdkHttpClient httpClient,
final Cipher cipher) {
super(providerLoggingCredentialsProvider, providerEventsLogger, platformEventsLogger, platformMetricsPublisher,
providerMetricsPublisher, validator, new Serializer(), httpClient, cipher);
super(providerLoggingCredentialsProvider, providerEventsLogger, platformEventsLogger, providerMetricsPublisher, validator,
new Serializer(), httpClient, cipher);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,8 @@

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isNull;
import static org.mockito.Mockito.lenient;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import com.fasterxml.jackson.core.type.TypeReference;
import java.io.ByteArrayOutputStream;
import java.io.File;
Expand All @@ -31,7 +26,6 @@
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.time.Instant;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
Expand Down Expand Up @@ -61,9 +55,6 @@ public class HookExecutableWrapperTest {
@Mock
private CredentialsProvider providerLoggingCredentialsProvider;

@Mock
private MetricsPublisher platformMetricsPublisher;

@Mock
private MetricsPublisher providerMetricsPublisher;

Expand All @@ -90,8 +81,8 @@ public class HookExecutableWrapperTest {
@BeforeEach
public void initWrapper() {
wrapper = new HookExecutableWrapperOverride(providerLoggingCredentialsProvider, platformEventsLogger,
providerEventsLogger, platformMetricsPublisher, providerMetricsPublisher,
validator, httpClient, cipher);
providerEventsLogger, providerMetricsPublisher, validator, httpClient,
cipher);
}

private static InputStream loadRequestStream(final String fileName) {
Expand All @@ -107,7 +98,6 @@ private static InputStream loadRequestStream(final String fileName) {

private void verifyInitialiseRuntime() {
verify(providerLoggingCredentialsProvider).setCredentials(any(Credentials.class));
verify(platformMetricsPublisher).refreshClient();
verify(providerMetricsPublisher).refreshClient();
}

Expand Down Expand Up @@ -151,15 +141,6 @@ public void invokeHandler_CompleteSynchronously_returnsSuccess(final String requ
// verify initialiseRuntime was called and initialised dependencies
verifyInitialiseRuntime();

// all metrics should be published, once for a single invocation
verify(platformMetricsPublisher, times(1)).publishInvocationMetric(any(Instant.class), eq(invocationPoint));
verify(platformMetricsPublisher, times(1)).publishDurationMetric(any(Instant.class), eq(invocationPoint), anyLong());
verify(platformMetricsPublisher, times(1)).publishExceptionByErrorCodeAndCountBulkMetrics(any(Instant.class),
any(HookInvocationPoint.class), isNull());

// validation failure metric should not be published
verifyNoMoreInteractions(platformMetricsPublisher);

// verify output response
verifyHandlerResponse(out,
HookProgressEvent.<TestContext>builder().clientRequestToken("123456").hookStatus(HookStatus.SUCCESS).build());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,12 @@ public class HookLambdaWrapperOverride extends HookLambdaWrapper<TestModel, Test
public HookLambdaWrapperOverride(final CredentialsProvider providerLoggingCredentialsProvider,
final LogPublisher platformEventsLogger,
final CloudWatchLogPublisher providerEventsLogger,
final MetricsPublisher platformMetricsPublisher,
final MetricsPublisher providerMetricsPublisher,
final SchemaValidator validator,
final SdkHttpClient httpClient,
final Cipher cipher) {
super(providerLoggingCredentialsProvider, providerEventsLogger, platformEventsLogger, platformMetricsPublisher,
providerMetricsPublisher, validator, new Serializer(), httpClient, cipher);
super(providerLoggingCredentialsProvider, providerEventsLogger, platformEventsLogger, providerMetricsPublisher, validator,
new Serializer(), httpClient, cipher);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,9 @@

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isNull;
import static org.mockito.Mockito.lenient;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import com.amazonaws.services.lambda.runtime.Context;
import com.amazonaws.services.lambda.runtime.LambdaLogger;
import com.fasterxml.jackson.core.type.TypeReference;
Expand All @@ -34,7 +29,6 @@
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.time.Instant;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
Expand Down Expand Up @@ -64,9 +58,6 @@ public class HookLambdaWrapperTest {
@Mock
private CredentialsProvider providerLoggingCredentialsProvider;

@Mock
private MetricsPublisher platformMetricsPublisher;

@Mock
private MetricsPublisher providerMetricsPublisher;

Expand Down Expand Up @@ -96,8 +87,7 @@ public class HookLambdaWrapperTest {
@BeforeEach
public void initWrapper() {
wrapper = new HookLambdaWrapperOverride(providerLoggingCredentialsProvider, platformEventsLogger, providerEventsLogger,
platformMetricsPublisher, providerMetricsPublisher, validator, httpClient,
cipher);
providerMetricsPublisher, validator, httpClient, cipher);
}

private static InputStream loadRequestStream(final String fileName) {
Expand All @@ -121,7 +111,6 @@ private Context getLambdaContext() {

private void verifyInitialiseRuntime() {
verify(providerLoggingCredentialsProvider).setCredentials(any(Credentials.class));
verify(platformMetricsPublisher).refreshClient();
verify(providerMetricsPublisher).refreshClient();
}

Expand Down Expand Up @@ -166,15 +155,6 @@ public void invokeHandler_CompleteSynchronously_returnsSuccess(final String requ
// verify initialiseRuntime was called and initialised dependencies
verifyInitialiseRuntime();

// all metrics should be published, once for a single invocation
verify(platformMetricsPublisher, times(1)).publishInvocationMetric(any(Instant.class), eq(invocationPoint));
verify(platformMetricsPublisher, times(1)).publishDurationMetric(any(Instant.class), eq(invocationPoint), anyLong());
verify(platformMetricsPublisher, times(1)).publishExceptionByErrorCodeAndCountBulkMetrics(any(Instant.class),
any(HookInvocationPoint.class), isNull());

// validation failure metric should not be published
verifyNoMoreInteractions(platformMetricsPublisher);

// verify output response
verifyHandlerResponse(out,
HookProgressEvent.<TestContext>builder().clientRequestToken("123456").hookStatus(HookStatus.SUCCESS).build());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,12 @@ public HookWrapperOverride(final LogPublisher platformEventsLogger,
public HookWrapperOverride(final CredentialsProvider providerLoggingCredentialsProvider,
final LogPublisher platformEventsLogger,
final CloudWatchLogPublisher providerEventsLogger,
final MetricsPublisher platformMetricsPublisher,
final MetricsPublisher providerMetricsPublisher,
final SchemaValidator validator,
final SdkHttpClient httpClient,
final Cipher cipher) {
super(providerLoggingCredentialsProvider, providerEventsLogger, platformEventsLogger, platformMetricsPublisher,
providerMetricsPublisher, validator, new Serializer(), httpClient, cipher);
super(providerLoggingCredentialsProvider, providerEventsLogger, platformEventsLogger, providerMetricsPublisher, validator,
new Serializer(), httpClient, cipher);
}

@Override
Expand Down
Loading