diff --git a/src/Microsoft.Data.SqlClient/add-ons/AzureKeyVaultProvider/AzureSqlKeyCryptographer.cs b/src/Microsoft.Data.SqlClient/add-ons/AzureKeyVaultProvider/AzureSqlKeyCryptographer.cs index 0dff6ac786..c4e9eb8396 100644 --- a/src/Microsoft.Data.SqlClient/add-ons/AzureKeyVaultProvider/AzureSqlKeyCryptographer.cs +++ b/src/Microsoft.Data.SqlClient/add-ons/AzureKeyVaultProvider/AzureSqlKeyCryptographer.cs @@ -7,12 +7,12 @@ using Azure.Security.KeyVault.Keys.Cryptography; using System; using System.Collections.Concurrent; -using System.Threading.Tasks; +using System.Threading; using static Azure.Security.KeyVault.Keys.Cryptography.SignatureAlgorithm; namespace Microsoft.Data.SqlClient.AlwaysEncrypted.AzureKeyVaultProvider { - internal class AzureSqlKeyCryptographer + internal sealed class AzureSqlKeyCryptographer : IDisposable { /// /// TokenCredential to be used with the KeyClient @@ -25,16 +25,14 @@ internal class AzureSqlKeyCryptographer private readonly ConcurrentDictionary _keyClientDictionary = new(); /// - /// Holds references to the fetch key tasks and maps them to their corresponding Azure Key Vault Key Identifier (URI). - /// These tasks will be used for returning the key in the event that the fetch task has not finished depositing the - /// key into the key dictionary. + /// Holds references to the Azure Key Vault keys and maps them to their corresponding Azure Key Vault Key Identifier (URI). /// - private readonly ConcurrentDictionary>> _keyFetchTaskDictionary = new(); + private readonly ConcurrentDictionary _keyDictionary = new(); /// - /// Holds references to the Azure Key Vault keys and maps them to their corresponding Azure Key Vault Key Identifier (URI). + /// SemaphoreSlim to ensure thread safety when accessing the key dictionary or making network calls to Azure Key Vault to fetch keys. /// - private readonly ConcurrentDictionary _keyDictionary = new(); + private SemaphoreSlim _keyDictionarySemaphore = new(1, 1); /// /// Holds references to the Azure Key Vault CryptographyClient objects and maps them to their corresponding Azure Key Vault Key Identifier (URI). @@ -50,20 +48,44 @@ internal AzureSqlKeyCryptographer(TokenCredential tokenCredential) TokenCredential = tokenCredential; } + /// + /// Disposes the SemaphoreSlim used for thread safety. + /// + public void Dispose() + { + _keyDictionarySemaphore.Dispose(); + } + /// /// Adds the key, specified by the Key Identifier URI, to the cache. + /// Validates the key type and fetches the key from Azure Key Vault if it is not already cached. /// /// internal void AddKey(string keyIdentifierUri) { - if (TheKeyHasNotBeenCached(keyIdentifierUri)) + // Allow only one thread to proceed to ensure thread safety + // as we will need to fetch key information from Azure Key Vault if the key is not found in cache. + _keyDictionarySemaphore.Wait(); + + try { - ParseAKVPath(keyIdentifierUri, out Uri vaultUri, out string keyName, out string keyVersion); - CreateKeyClient(vaultUri); - FetchKey(vaultUri, keyName, keyVersion, keyIdentifierUri); - } + if (!_keyDictionary.ContainsKey(keyIdentifierUri)) + { + ParseAKVPath(keyIdentifierUri, out Uri vaultUri, out string keyName, out string keyVersion); + + // Fetch the KeyClient for the Key vault URI. + KeyClient keyClient = GetOrCreateKeyClient(vaultUri); + + // Fetch the key from Azure Key Vault. + KeyVaultKey key = FetchKeyFromKeyVault(keyClient, keyName, keyVersion); - bool TheKeyHasNotBeenCached(string k) => !_keyDictionary.ContainsKey(k) && !_keyFetchTaskDictionary.ContainsKey(k); + _keyDictionary.AddOrUpdate(keyIdentifierUri, key, (k, v) => key); + } + } + finally + { + _keyDictionarySemaphore.Release(); + } } /// @@ -75,18 +97,12 @@ internal KeyVaultKey GetKey(string keyIdentifierUri) { if (_keyDictionary.TryGetValue(keyIdentifierUri, out KeyVaultKey key)) { - AKVEventSource.Log.TryTraceEvent("Fetched master key from cache"); + AKVEventSource.Log.TryTraceEvent("Fetched key name={0} from cache", key.Name); return key; } - if (_keyFetchTaskDictionary.TryGetValue(keyIdentifierUri, out Task> task)) - { - AKVEventSource.Log.TryTraceEvent("New Master key fetched."); - return Task.Run(() => task).GetAwaiter().GetResult(); - } - // Not a public exception - not likely to occur. - AKVEventSource.Log.TryTraceEvent("Master key not found."); + AKVEventSource.Log.TryTraceEvent("Key not found; URI={0}", keyIdentifierUri); throw ADP.MasterKeyNotFound(keyIdentifierUri); } @@ -95,10 +111,7 @@ internal KeyVaultKey GetKey(string keyIdentifierUri) /// /// The key vault key identifier URI /// - internal int GetKeySize(string keyIdentifierUri) - { - return GetKey(keyIdentifierUri).Key.N.Length; - } + internal int GetKeySize(string keyIdentifierUri) => GetKey(keyIdentifierUri).Key.N.Length; /// /// Generates signature based on RSA PKCS#v1.5 scheme using a specified Azure Key Vault Key URL. @@ -142,41 +155,50 @@ private CryptographyClient GetCryptographyClient(string keyIdentifierUri) CryptographyClient cryptographyClient = new(GetKey(keyIdentifierUri).Id, TokenCredential); _cryptoClientDictionary.TryAdd(keyIdentifierUri, cryptographyClient); - return cryptographyClient; } /// - /// + /// Fetches the column encryption key from the Azure Key Vault. /// - /// The Azure Key Vault URI + /// The KeyClient instance /// The name of the Azure Key Vault key /// The version of the Azure Key Vault key - /// The Azure Key Vault key identifier - private void FetchKey(Uri vaultUri, string keyName, string keyVersion, string keyResourceUri) + private KeyVaultKey FetchKeyFromKeyVault(KeyClient keyClient, string keyName, string keyVersion) { - Task> fetchKeyTask = FetchKeyFromKeyVault(vaultUri, keyName, keyVersion); - _keyFetchTaskDictionary.AddOrUpdate(keyResourceUri, fetchKeyTask, (k, v) => fetchKeyTask); + AKVEventSource.Log.TryTraceEvent("Fetching key name={0}", keyName); - fetchKeyTask - .ContinueWith(k => ValidateRsaKey(k.GetAwaiter().GetResult())) - .ContinueWith(k => _keyDictionary.AddOrUpdate(keyResourceUri, k.GetAwaiter().GetResult(), (key, v) => k.GetAwaiter().GetResult())); + Azure.Response keyResponse = keyClient?.GetKey(keyName, keyVersion); - Task.Run(() => fetchKeyTask); + // Handle the case where the key response is null or contains an error + // This can happen if the key does not exist or if there is an issue with the KeyClient. + // In such cases, we log the error and throw an exception. + if (keyResponse == null || keyResponse.Value == null || keyResponse.GetRawResponse().IsError) + { + AKVEventSource.Log.TryTraceEvent("Get Key failed to fetch Key from Azure Key Vault for key {0}, version {1}", keyName, keyVersion); + if (keyResponse?.GetRawResponse() is Azure.Response response) + { + AKVEventSource.Log.TryTraceEvent("Response status {0} : {1}", response.Status, response.ReasonPhrase); + } + throw ADP.GetKeyFailed(keyName); + } + + KeyVaultKey key = keyResponse.Value; + + // Validate that the key is of type RSA + key = ValidateRsaKey(key); + return key; } /// - /// Looks up the KeyClient object by it's URI and then fetches the key by name. + /// Gets or creates a KeyClient for the specified Azure Key Vault URI. /// - /// The Azure Key Vault URI - /// Then name of the key - /// Then version of the key + /// Key Identifier URL /// - private Task> FetchKeyFromKeyVault(Uri vaultUri, string keyName, string keyVersion) + private KeyClient GetOrCreateKeyClient(Uri vaultUri) { - _keyClientDictionary.TryGetValue(vaultUri, out KeyClient keyClient); - AKVEventSource.Log.TryTraceEvent("Fetching requested master key: {0}", keyName); - return keyClient?.GetKeyAsync(keyName, keyVersion); + return _keyClientDictionary.GetOrAdd( + vaultUri, (_) => new KeyClient(vaultUri, TokenCredential)); } /// @@ -184,7 +206,7 @@ private void FetchKey(Uri vaultUri, string keyName, string keyVersion, string ke /// /// /// - private KeyVaultKey ValidateRsaKey(KeyVaultKey key) + private static KeyVaultKey ValidateRsaKey(KeyVaultKey key) { if (key.KeyType != KeyType.Rsa && key.KeyType != KeyType.RsaHsm) { @@ -195,18 +217,6 @@ private KeyVaultKey ValidateRsaKey(KeyVaultKey key) return key; } - /// - /// Instantiates and adds a KeyClient to the KeyClient dictionary - /// - /// The Azure Key Vault URI - private void CreateKeyClient(Uri vaultUri) - { - if (!_keyClientDictionary.ContainsKey(vaultUri)) - { - _keyClientDictionary.TryAdd(vaultUri, new KeyClient(vaultUri, TokenCredential)); - } - } - /// /// Validates and parses the Azure Key Vault URI and key name. /// @@ -214,7 +224,7 @@ private void CreateKeyClient(Uri vaultUri) /// The Azure Key Vault URI /// The name of the key /// The version of the key - private void ParseAKVPath(string masterKeyPath, out Uri vaultUri, out string masterKeyName, out string masterKeyVersion) + private static void ParseAKVPath(string masterKeyPath, out Uri vaultUri, out string masterKeyName, out string masterKeyVersion) { Uri masterKeyPathUri = new(masterKeyPath); vaultUri = new Uri(masterKeyPathUri.GetLeftPart(UriPartial.Authority)); diff --git a/src/Microsoft.Data.SqlClient/add-ons/AzureKeyVaultProvider/LocalCache.cs b/src/Microsoft.Data.SqlClient/add-ons/AzureKeyVaultProvider/LocalCache.cs index 3e17f5d951..7fbffe0ae6 100644 --- a/src/Microsoft.Data.SqlClient/add-ons/AzureKeyVaultProvider/LocalCache.cs +++ b/src/Microsoft.Data.SqlClient/add-ons/AzureKeyVaultProvider/LocalCache.cs @@ -2,8 +2,8 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.Extensions.Caching.Memory; using System; +using Microsoft.Extensions.Caching.Memory; using static System.Math; namespace Microsoft.Data.SqlClient.AlwaysEncrypted.AzureKeyVaultProvider @@ -92,6 +92,7 @@ internal TValue GetOrCreate(TKey key, Func createItem) /// /// Determines whether the LocalCache contains the specified key. + /// Used in unit tests to verify that the cache contains the expected entries. /// /// /// diff --git a/src/Microsoft.Data.SqlClient/add-ons/AzureKeyVaultProvider/SqlColumnEncryptionAzureKeyVaultProvider.cs b/src/Microsoft.Data.SqlClient/add-ons/AzureKeyVaultProvider/SqlColumnEncryptionAzureKeyVaultProvider.cs index eb8c8d77c4..d4740a1183 100644 --- a/src/Microsoft.Data.SqlClient/add-ons/AzureKeyVaultProvider/SqlColumnEncryptionAzureKeyVaultProvider.cs +++ b/src/Microsoft.Data.SqlClient/add-ons/AzureKeyVaultProvider/SqlColumnEncryptionAzureKeyVaultProvider.cs @@ -4,6 +4,7 @@ using System; using System.Text; +using System.Threading; using Azure.Core; using Azure.Security.KeyVault.Keys.Cryptography; using static Microsoft.Data.SqlClient.AlwaysEncrypted.AzureKeyVaultProvider.Validator; @@ -55,6 +56,8 @@ public class SqlColumnEncryptionAzureKeyVaultProvider : SqlColumnEncryptionKeySt private readonly static KeyWrapAlgorithm s_keyWrapAlgorithm = KeyWrapAlgorithm.RsaOaep; + private SemaphoreSlim _cacheSemaphore = new(1, 1); + /// /// List of Trusted Endpoints /// @@ -69,7 +72,7 @@ public class SqlColumnEncryptionAzureKeyVaultProvider : SqlColumnEncryptionKeySt /// /// A cache for storing the results of signature verification of column master key metadata. /// - private readonly LocalCache, bool> _columnMasterKeyMetadataSignatureVerificationCache = + private readonly LocalCache, bool> _columnMasterKeyMetadataSignatureVerificationCache = new(maxSizeLimit: 2000) { TimeToLive = TimeSpan.FromDays(10) }; /// @@ -230,7 +233,7 @@ byte[] DecryptEncryptionKey() // Get ciphertext byte[] cipherText = new byte[cipherTextLength]; Array.Copy(encryptedColumnEncryptionKey, currentIndex, cipherText, 0, cipherTextLength); - + currentIndex += cipherTextLength; // Get signature @@ -394,17 +397,10 @@ private byte[] CompileMasterKeyMetadata(string masterKeyPath, bool allowEnclaveC /// An array of bytes to convert. /// A string of hexadecimal characters /// - /// Produces a string of hexadecimal character pairs preceded with "0x", where each pair represents the corresponding element in value; for example, "0x7F2C4A00". + /// Produces a string of hexadecimal character pairs preceded with "0x", where each pair represents the corresponding element in source; for example, "0x7F2C4A00". /// private string ToHexString(byte[] source) - { - if (source is null) - { - return null; - } - - return "0x" + BitConverter.ToString(source).Replace("-", ""); - } + => source is null ? null : "0x" + BitConverter.ToString(source).Replace("-", ""); /// /// Returns the cached decrypted column encryption key, or unwraps the encrypted column encryption key if not present. @@ -415,8 +411,21 @@ private string ToHexString(byte[] source) /// /// /// - private byte[] GetOrCreateColumnEncryptionKey(string encryptedColumnEncryptionKey, Func createItem) - => _columnEncryptionKeyCache.GetOrCreate(encryptedColumnEncryptionKey, createItem); + private byte[] GetOrCreateColumnEncryptionKey(string encryptedColumnEncryptionKey, Func createItem) + { + // Allow only one thread to access the cache at a time. + _cacheSemaphore.Wait(); + + try + { + return _columnEncryptionKeyCache.GetOrCreate(encryptedColumnEncryptionKey, createItem); + } + finally + { + // Release the semaphore to allow other threads to access the cache. + _cacheSemaphore.Release(); + } + } /// /// Returns the cached signature verification result, or proceeds to verify if not present. diff --git a/src/Microsoft.Data.SqlClient/add-ons/AzureKeyVaultProvider/Strings.Designer.cs b/src/Microsoft.Data.SqlClient/add-ons/AzureKeyVaultProvider/Strings.Designer.cs index fc5d88930a..c3b9ce9104 100644 --- a/src/Microsoft.Data.SqlClient/add-ons/AzureKeyVaultProvider/Strings.Designer.cs +++ b/src/Microsoft.Data.SqlClient/add-ons/AzureKeyVaultProvider/Strings.Designer.cs @@ -8,6 +8,8 @@ // //------------------------------------------------------------------------------ +using System.Globalization; + namespace Microsoft.Data.SqlClient.AlwaysEncrypted.AzureKeyVaultProvider { /// @@ -88,6 +90,17 @@ internal static string EmptyArgumentInternal } } + /// + /// Looks up a localized string similar to: Failed to fetch key from Azure Key Vault. Key: {0}. + /// + internal static string GetKeyFailed + { + get + { + return ResourceManager.GetString("GetKeyFailed", resourceCulture); + } + } + /// /// Looks up a localized string similar to Signed hash length does not match the RSA key size.. /// @@ -199,7 +212,18 @@ internal static string InvalidSignatureTemplate } /// - /// Looks up a localized string similar to Invalid trusted endpoint specified: '{0}'; a trusted endpoint must have a value.. + /// Looks up a localized string similar to The key with identifier '{0}' was not found.. + /// + internal static string MasterKeyNotFound + { + get + { + return ResourceManager.GetString("MasterKeyNotFound", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to One or more of the elements in '{0}' are null or empty or consist of only whitespace.. /// internal static string NullOrWhitespaceForEach { diff --git a/src/Microsoft.Data.SqlClient/add-ons/AzureKeyVaultProvider/Strings.resx b/src/Microsoft.Data.SqlClient/add-ons/AzureKeyVaultProvider/Strings.resx index 039d1079d5..8775b345ab 100644 --- a/src/Microsoft.Data.SqlClient/add-ons/AzureKeyVaultProvider/Strings.resx +++ b/src/Microsoft.Data.SqlClient/add-ons/AzureKeyVaultProvider/Strings.resx @@ -118,13 +118,16 @@ System.Resources.ResXResourceWriter, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 - One or more of the elements in {0} are null or empty or consist of only whitespace. + One or more of the elements in '{0}' are null or empty or consist of only whitespace. CipherText length does not match the RSA key size. - Internal error. Empty {0} specified. + Internal error. Empty '{0}' specified. + + + Failed to fetch key from Azure Key Vault. Key: {0}. The key with identifier '{0}' was not found. diff --git a/src/Microsoft.Data.SqlClient/add-ons/AzureKeyVaultProvider/Utils.cs b/src/Microsoft.Data.SqlClient/add-ons/AzureKeyVaultProvider/Utils.cs index f71080ffab..0eb4dc1c9b 100644 --- a/src/Microsoft.Data.SqlClient/add-ons/AzureKeyVaultProvider/Utils.cs +++ b/src/Microsoft.Data.SqlClient/add-ons/AzureKeyVaultProvider/Utils.cs @@ -86,7 +86,10 @@ internal static ArgumentException NullOrWhitespaceForEach(string name) => new(string.Format(Strings.NullOrWhitespaceForEach, name)); internal static KeyNotFoundException MasterKeyNotFound(string masterKeyPath) => - new(string.Format(CultureInfo.InvariantCulture, Strings.InvalidSignatureTemplate, masterKeyPath)); + new(string.Format(CultureInfo.InvariantCulture, Strings.MasterKeyNotFound, masterKeyPath)); + + internal static KeyNotFoundException GetKeyFailed(string masterKeyPath) => + new(string.Format(CultureInfo.InvariantCulture, Strings.GetKeyFailed, masterKeyPath)); internal static FormatException NonRsaKeyFormat(string keyType) => new(string.Format(CultureInfo.InvariantCulture, Strings.NonRsaKeyTemplate, keyType)); diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnection.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnection.cs index 45df50badb..72dec5a28e 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnection.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnection.cs @@ -236,7 +236,7 @@ private SqlConnection(SqlConnection connection) internal static bool TryGetSystemColumnEncryptionKeyStoreProvider(string keyStoreName, out SqlColumnEncryptionKeyStoreProvider provider) { - return s_systemColumnEncryptionKeyStoreProviders.TryGetValue(keyStoreName, out provider); + return s_systemColumnEncryptionKeyStoreProviders.TryGetValue(keyStoreName, out provider); } /// @@ -276,9 +276,9 @@ internal static List GetColumnEncryptionSystemKeyStoreProvidersNames() { if (s_systemColumnEncryptionKeyStoreProviders.Count > 0) { - return new List(s_systemColumnEncryptionKeyStoreProviders.Keys); + return [.. s_systemColumnEncryptionKeyStoreProviders.Keys]; } - return new List(0); + return []; } /// @@ -291,13 +291,13 @@ internal List GetColumnEncryptionCustomKeyStoreProvidersNames() if (_customColumnEncryptionKeyStoreProviders is not null && _customColumnEncryptionKeyStoreProviders.Count > 0) { - return new List(_customColumnEncryptionKeyStoreProviders.Keys); + return [.. _customColumnEncryptionKeyStoreProviders.Keys]; } if (s_globalCustomColumnEncryptionKeyStoreProviders is not null) { - return new List(s_globalCustomColumnEncryptionKeyStoreProviders.Keys); + return [.. s_globalCustomColumnEncryptionKeyStoreProviders.Keys]; } - return new List(0); + return []; } /// @@ -325,7 +325,9 @@ public static void RegisterColumnEncryptionKeyStoreProviders(IDictionary(cacheLookupKey, encryptionKey, options); + // In case multiple threads reach here at the same time, the first one wins. + // The allocated memory will be reclaimed by Garbage Collector. + MemoryCacheEntryOptions options = new() + { + AbsoluteExpirationRelativeToNow = SqlConnection.ColumnEncryptionKeyCacheTtl + }; + _cache.Set(cacheLookupKey, encryptionKey, options); + } } - } - return encryptionKey; + return encryptionKey; + } + finally + { + // Release the lock to allow other threads to access the cache + _cacheLock.Release(); + } } } } diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/ExceptionTestAKVStore.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/ExceptionTestAKVStore.cs index 2465633a03..15679176ac 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/ExceptionTestAKVStore.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/ExceptionTestAKVStore.cs @@ -54,7 +54,7 @@ public void NullEncryptionAlgorithm() public void EmptyColumnEncryptionKey() { Exception ex1 = Assert.Throws(() => _fixture.AkvStoreProvider.EncryptColumnEncryptionKey(_fixture.AkvKeyUrl, MasterKeyEncAlgo, new byte[] { })); - Assert.Matches($@"Internal error. Empty columnEncryptionKey specified.", ex1.Message); + Assert.Matches($@"Internal error. Empty 'columnEncryptionKey' specified.", ex1.Message); } [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsAKVSetupAvailable))] @@ -68,7 +68,7 @@ public void NullColumnEncryptionKey() public void EmptyEncryptedColumnEncryptionKey() { Exception ex1 = Assert.Throws(() => _fixture.AkvStoreProvider.DecryptColumnEncryptionKey(_fixture.AkvKeyUrl, MasterKeyEncAlgo, new byte[] { })); - Assert.Matches($@"Internal error. Empty encryptedColumnEncryptionKey specified", ex1.Message); + Assert.Matches($@"Internal error. Empty 'encryptedColumnEncryptionKey' specified", ex1.Message); } [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsAKVSetupAvailable))] @@ -250,7 +250,7 @@ public void InvalidTrustedEndpoints(string[] trustedEndpoints) SqlColumnEncryptionAzureKeyVaultProvider azureKeyProvider = new SqlColumnEncryptionAzureKeyVaultProvider( new SqlClientCustomTokenCredential(), trustedEndpoints); }); - Assert.Matches("One or more of the elements in trustedEndpoints are null or empty or consist of only whitespace.", ex.Message); + Assert.Matches("One or more of the elements in 'trustedEndpoints' are null or empty or consist of only whitespace.", ex.Message); } [InlineData(null)] @@ -264,7 +264,7 @@ public void InvalidTrustedEndpoint(string trustedEndpoint) SqlColumnEncryptionAzureKeyVaultProvider azureKeyProvider = new SqlColumnEncryptionAzureKeyVaultProvider( new SqlClientCustomTokenCredential(), trustedEndpoint); }); - Assert.Matches("One or more of the elements in trustedEndpoints are null or empty or consist of only whitespace.", ex.Message); + Assert.Matches("One or more of the elements in 'trustedEndpoints' are null or empty or consist of only whitespace.", ex.Message); } } }