Skip to content

fix: do not truncate encrypted streams when offset is greater than zero #2113

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ public FrameEncryptionHandler(
* </ol>
*
* @param in the input byte array.
* @param inOff the offset into the in array where the data to be encrypted starts.
* @param inLen the number of bytes to be encrypted.
* @param off the offset into the in array where the data to be encrypted starts.
* @param len the number of bytes to be encrypted.
* @param out the output buffer the encrypted bytes go into.
* @param outOff the offset into the output byte array the encrypted data starts at.
* @return the number of bytes written to out and processed
Expand All @@ -95,13 +95,13 @@ public ProcessingSummary processBytes(
int actualOutLen = 0;

int size = len;
int offset = off;
int processedBytes = 0;
while (size > 0) {
final int currentFrameCapacity = frameSize_ - bytesToFrameLen_;
// bind size to the capacity of the current frame
size = Math.min(currentFrameCapacity, size);

System.arraycopy(in, offset, bytesToFrame_, bytesToFrameLen_, size);
System.arraycopy(in, off + processedBytes, bytesToFrame_, bytesToFrameLen_, size);
bytesToFrameLen_ += size;

// check if there is enough bytes to create a frame
Expand All @@ -113,10 +113,10 @@ public ProcessingSummary processBytes(
bytesToFrameLen_ = 0;
}

// update offset by the size of bytes being encrypted.
offset += size;
// update size to the remaining bytes starting at offset.
size = len - offset;
// add the size of this frame to processedBytes
processedBytes += size;
// remaining size is original len minus processedBytes
size = len - processedBytes;
}

return new ProcessingSummary(actualOutLen, len);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,29 @@

import com.amazonaws.encryptionsdk.AwsCrypto;
import com.amazonaws.encryptionsdk.CryptoAlgorithm;
import com.amazonaws.encryptionsdk.CryptoInputStream;
import com.amazonaws.encryptionsdk.CryptoOutputStream;
import com.amazonaws.encryptionsdk.CryptoResult;
import com.amazonaws.encryptionsdk.TestUtils;
import com.amazonaws.encryptionsdk.model.CipherFrameHeaders;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.lang.reflect.Field;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.security.SecureRandom;
import java.util.Collections;
import javax.crypto.SecretKey;
import javax.crypto.spec.SecretKeySpec;
import org.bouncycastle.util.encoders.Hex;
import org.junit.Before;
import org.junit.Test;
import software.amazon.awssdk.utils.StringUtils;
import software.amazon.cryptography.materialproviders.IKeyring;
import software.amazon.cryptography.materialproviders.MaterialProviders;
import software.amazon.cryptography.materialproviders.model.AesWrappingAlg;
import software.amazon.cryptography.materialproviders.model.CreateRawAesKeyringInput;
import software.amazon.cryptography.materialproviders.model.MaterialProvidersConfig;

public class FrameEncryptionHandlerTest {
private final CryptoAlgorithm cryptoAlgorithm_ = TestUtils.DEFAULT_TEST_CRYPTO_ALG;
Expand Down Expand Up @@ -117,4 +132,86 @@ private void assertHeaderNonce(byte[] expectedNonce, byte[] buf) {
private void generateTestBlock(byte[] buf) {
frameEncryptionHandler_.processBytes(new byte[frameSize_], 0, frameSize_, buf, 0);
}

/**
* This isn't a unit test, but it reproduces a bug in the FrameEncryptionHandler where the stream
* would be truncated when the offset is >0
*
* @throws Exception
*/
@Test
public void testStreamTruncation() throws Exception {
// Initialize AES key and keyring
SecureRandom rnd = new SecureRandom();
byte[] rawKey = new byte[16];
rnd.nextBytes(rawKey);
SecretKeySpec cryptoKey = new SecretKeySpec(rawKey, "AES");
MaterialProviders materialProviders =
MaterialProviders.builder()
.MaterialProvidersConfig(MaterialProvidersConfig.builder().build())
.build();
CreateRawAesKeyringInput keyringInput =
CreateRawAesKeyringInput.builder()
.wrappingKey(ByteBuffer.wrap(cryptoKey.getEncoded()))
.keyNamespace("Example")
.keyName("RandomKey")
.wrappingAlg(AesWrappingAlg.ALG_AES128_GCM_IV12_TAG16)
.build();
IKeyring keyring = materialProviders.CreateRawAesKeyring(keyringInput);
AwsCrypto crypto = AwsCrypto.standard();

String testDataString = StringUtils.repeat("Hello, World! ", 5_000);

int startOffset = 100; // The data will start from this offset
byte[] inputDataWithOffset = new byte[10_000];
// the length of the actual data
int dataLength = inputDataWithOffset.length - startOffset;
// copy some data, starting at the startOffset
// so the first |startOffset| bytes are 0s
System.arraycopy(
testDataString.getBytes(StandardCharsets.UTF_8),
0,
inputDataWithOffset,
startOffset,
dataLength);
// decryptData (non-streaming) doesn't know about the offset
// it will strip out the original 0s
byte[] expectedOutput = new byte[10_000 - startOffset];
System.arraycopy(
testDataString.getBytes(StandardCharsets.UTF_8), 0, expectedOutput, 0, dataLength);

// Encrypt the data
byte[] encryptedData;
try (ByteArrayOutputStream os = new ByteArrayOutputStream()) {
try (CryptoOutputStream cryptoOutput =
crypto.createEncryptingStream(keyring, os, Collections.emptyMap())) {
cryptoOutput.write(inputDataWithOffset, startOffset, dataLength);
}
encryptedData = os.toByteArray();
}

// Check non-streaming decrypt
CryptoResult<byte[], ?> nonStreamDecrypt = crypto.decryptData(keyring, encryptedData);
assertEquals(dataLength, nonStreamDecrypt.getResult().length);
assertArrayEquals(expectedOutput, nonStreamDecrypt.getResult());

// Check streaming decrypt
int decryptedLength = 0;
byte[] decryptedData = new byte[inputDataWithOffset.length];
try (ByteArrayInputStream is = new ByteArrayInputStream(encryptedData);
CryptoInputStream cryptoInput = crypto.createDecryptingStream(keyring, is)) {
int offset = startOffset;
do {
int bytesRead = cryptoInput.read(decryptedData, offset, decryptedData.length - offset);
if (bytesRead <= 0) {
break; // End of stream
}
offset += bytesRead;
decryptedLength += bytesRead;
} while (true);
}
assertEquals(dataLength, decryptedLength);
// These arrays will be offset, i.e. the first |startOffset| bytes are 0s
assertArrayEquals(inputDataWithOffset, decryptedData);
}
}