diff --git a/src/main/java/com/amazonaws/encryptionsdk/internal/FrameEncryptionHandler.java b/src/main/java/com/amazonaws/encryptionsdk/internal/FrameEncryptionHandler.java index d9fc7f63..be2bd123 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/internal/FrameEncryptionHandler.java +++ b/src/main/java/com/amazonaws/encryptionsdk/internal/FrameEncryptionHandler.java @@ -81,8 +81,8 @@ public FrameEncryptionHandler( * * * @param in the input byte array. - * @param inOff the offset into the in array where the data to be encrypted starts. - * @param inLen the number of bytes to be encrypted. + * @param off the offset into the in array where the data to be encrypted starts. + * @param len the number of bytes to be encrypted. * @param out the output buffer the encrypted bytes go into. * @param outOff the offset into the output byte array the encrypted data starts at. * @return the number of bytes written to out and processed @@ -95,13 +95,13 @@ public ProcessingSummary processBytes( int actualOutLen = 0; int size = len; - int offset = off; + int processedBytes = 0; while (size > 0) { final int currentFrameCapacity = frameSize_ - bytesToFrameLen_; // bind size to the capacity of the current frame size = Math.min(currentFrameCapacity, size); - System.arraycopy(in, offset, bytesToFrame_, bytesToFrameLen_, size); + System.arraycopy(in, off + processedBytes, bytesToFrame_, bytesToFrameLen_, size); bytesToFrameLen_ += size; // check if there is enough bytes to create a frame @@ -113,10 +113,10 @@ public ProcessingSummary processBytes( bytesToFrameLen_ = 0; } - // update offset by the size of bytes being encrypted. - offset += size; - // update size to the remaining bytes starting at offset. - size = len - offset; + // add the size of this frame to processedBytes + processedBytes += size; + // remaining size is original len minus processedBytes + size = len - processedBytes; } return new ProcessingSummary(actualOutLen, len); diff --git a/src/test/java/com/amazonaws/encryptionsdk/internal/FrameEncryptionHandlerTest.java b/src/test/java/com/amazonaws/encryptionsdk/internal/FrameEncryptionHandlerTest.java index da6bf08f..542ca995 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/internal/FrameEncryptionHandlerTest.java +++ b/src/test/java/com/amazonaws/encryptionsdk/internal/FrameEncryptionHandlerTest.java @@ -19,14 +19,29 @@ import com.amazonaws.encryptionsdk.AwsCrypto; import com.amazonaws.encryptionsdk.CryptoAlgorithm; +import com.amazonaws.encryptionsdk.CryptoInputStream; +import com.amazonaws.encryptionsdk.CryptoOutputStream; +import com.amazonaws.encryptionsdk.CryptoResult; import com.amazonaws.encryptionsdk.TestUtils; import com.amazonaws.encryptionsdk.model.CipherFrameHeaders; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.lang.reflect.Field; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.security.SecureRandom; +import java.util.Collections; import javax.crypto.SecretKey; import javax.crypto.spec.SecretKeySpec; import org.bouncycastle.util.encoders.Hex; import org.junit.Before; import org.junit.Test; +import software.amazon.awssdk.utils.StringUtils; +import software.amazon.cryptography.materialproviders.IKeyring; +import software.amazon.cryptography.materialproviders.MaterialProviders; +import software.amazon.cryptography.materialproviders.model.AesWrappingAlg; +import software.amazon.cryptography.materialproviders.model.CreateRawAesKeyringInput; +import software.amazon.cryptography.materialproviders.model.MaterialProvidersConfig; public class FrameEncryptionHandlerTest { private final CryptoAlgorithm cryptoAlgorithm_ = TestUtils.DEFAULT_TEST_CRYPTO_ALG; @@ -117,4 +132,86 @@ private void assertHeaderNonce(byte[] expectedNonce, byte[] buf) { private void generateTestBlock(byte[] buf) { frameEncryptionHandler_.processBytes(new byte[frameSize_], 0, frameSize_, buf, 0); } + + /** + * This isn't a unit test, but it reproduces a bug in the FrameEncryptionHandler where the stream + * would be truncated when the offset is >0 + * + * @throws Exception + */ + @Test + public void testStreamTruncation() throws Exception { + // Initialize AES key and keyring + SecureRandom rnd = new SecureRandom(); + byte[] rawKey = new byte[16]; + rnd.nextBytes(rawKey); + SecretKeySpec cryptoKey = new SecretKeySpec(rawKey, "AES"); + MaterialProviders materialProviders = + MaterialProviders.builder() + .MaterialProvidersConfig(MaterialProvidersConfig.builder().build()) + .build(); + CreateRawAesKeyringInput keyringInput = + CreateRawAesKeyringInput.builder() + .wrappingKey(ByteBuffer.wrap(cryptoKey.getEncoded())) + .keyNamespace("Example") + .keyName("RandomKey") + .wrappingAlg(AesWrappingAlg.ALG_AES128_GCM_IV12_TAG16) + .build(); + IKeyring keyring = materialProviders.CreateRawAesKeyring(keyringInput); + AwsCrypto crypto = AwsCrypto.standard(); + + String testDataString = StringUtils.repeat("Hello, World! ", 5_000); + + int startOffset = 100; // The data will start from this offset + byte[] inputDataWithOffset = new byte[10_000]; + // the length of the actual data + int dataLength = inputDataWithOffset.length - startOffset; + // copy some data, starting at the startOffset + // so the first |startOffset| bytes are 0s + System.arraycopy( + testDataString.getBytes(StandardCharsets.UTF_8), + 0, + inputDataWithOffset, + startOffset, + dataLength); + // decryptData (non-streaming) doesn't know about the offset + // it will strip out the original 0s + byte[] expectedOutput = new byte[10_000 - startOffset]; + System.arraycopy( + testDataString.getBytes(StandardCharsets.UTF_8), 0, expectedOutput, 0, dataLength); + + // Encrypt the data + byte[] encryptedData; + try (ByteArrayOutputStream os = new ByteArrayOutputStream()) { + try (CryptoOutputStream cryptoOutput = + crypto.createEncryptingStream(keyring, os, Collections.emptyMap())) { + cryptoOutput.write(inputDataWithOffset, startOffset, dataLength); + } + encryptedData = os.toByteArray(); + } + + // Check non-streaming decrypt + CryptoResult nonStreamDecrypt = crypto.decryptData(keyring, encryptedData); + assertEquals(dataLength, nonStreamDecrypt.getResult().length); + assertArrayEquals(expectedOutput, nonStreamDecrypt.getResult()); + + // Check streaming decrypt + int decryptedLength = 0; + byte[] decryptedData = new byte[inputDataWithOffset.length]; + try (ByteArrayInputStream is = new ByteArrayInputStream(encryptedData); + CryptoInputStream cryptoInput = crypto.createDecryptingStream(keyring, is)) { + int offset = startOffset; + do { + int bytesRead = cryptoInput.read(decryptedData, offset, decryptedData.length - offset); + if (bytesRead <= 0) { + break; // End of stream + } + offset += bytesRead; + decryptedLength += bytesRead; + } while (true); + } + assertEquals(dataLength, decryptedLength); + // These arrays will be offset, i.e. the first |startOffset| bytes are 0s + assertArrayEquals(inputDataWithOffset, decryptedData); + } }