diff --git a/src/MongoDB.Bson/Serialization/Serializers/BsonClassMapSerializer.cs b/src/MongoDB.Bson/Serialization/Serializers/BsonClassMapSerializer.cs index fdcd59916ed..bec64ba12af 100644 --- a/src/MongoDB.Bson/Serialization/Serializers/BsonClassMapSerializer.cs +++ b/src/MongoDB.Bson/Serialization/Serializers/BsonClassMapSerializer.cs @@ -14,9 +14,11 @@ */ using System; +using System.Buffers; using System.Collections.Generic; using System.ComponentModel; using System.Reflection; +using System.Runtime.CompilerServices; using MongoDB.Bson.IO; using MongoDB.Bson.Serialization.Conventions; using MongoDB.Bson.Serialization.Serializers; @@ -82,7 +84,7 @@ public override TClass Deserialize(BsonDeserializationContext context, BsonDeser { var bsonReader = context.Reader; - if (bsonReader.GetCurrentBsonType() == Bson.BsonType.Null) + if (bsonReader.GetCurrentBsonType() == BsonType.Null) { bsonReader.ReadNull(); return default(TClass); @@ -149,7 +151,9 @@ public TClass DeserializeClass(BsonDeserializationContext context) var discriminatorConvention = _classMap.GetDiscriminatorConvention(); var allMemberMaps = _classMap.AllMemberMaps; var extraElementsMemberMapIndex = _classMap.ExtraElementsMemberMapIndex; - var memberMapBitArray = FastMemberMapHelper.GetBitArray(allMemberMaps.Count); + + var (lengthInUInts, useStackAlloc) = FastMemberMapHelper.GetLengthInUInts(allMemberMaps.Count); + using var bitArray = useStackAlloc ? FastMemberMapHelper.GetMembersBitArray(stackalloc uint[lengthInUInts]) : FastMemberMapHelper.GetMembersBitArray(lengthInUInts); bsonReader.ReadStartDocument(); var elementTrie = _classMap.ElementTrie; @@ -193,7 +197,8 @@ public TClass DeserializeClass(BsonDeserializationContext context) DeserializeExtraElementValue(context, values, elementName, memberMap); } } - memberMapBitArray[memberMapIndex >> 5] |= 1U << (memberMapIndex & 31); + + bitArray.SetMemberIndex(memberMapIndex); } else { @@ -221,7 +226,7 @@ public TClass DeserializeClass(BsonDeserializationContext context) { DeserializeExtraElementValue(context, values, elementName, extraElementsMemberMap); } - memberMapBitArray[extraElementsMemberMapIndex >> 5] |= 1U << (extraElementsMemberMapIndex & 31); + bitArray.SetMemberIndex(extraElementsMemberMapIndex); } else if (_classMap.IgnoreExtraElements) { @@ -239,51 +244,38 @@ public TClass DeserializeClass(BsonDeserializationContext context) bsonReader.ReadEndDocument(); // check any members left over that we didn't have elements for (in blocks of 32 elements at a time) - for (var bitArrayIndex = 0; bitArrayIndex < memberMapBitArray.Length; ++bitArrayIndex) + var bitArraySpan = bitArray.Span; + for (var bitArrayIndex = 0; bitArrayIndex < bitArraySpan.Length; bitArrayIndex++) { var memberMapIndex = bitArrayIndex << 5; - var memberMapBlock = ~memberMapBitArray[bitArrayIndex]; // notice that bits are flipped so 1's are now the missing elements + var memberMapBlock = ~bitArraySpan[bitArrayIndex]; // notice that bits are flipped so 1's are now the missing elements // work through this memberMapBlock of 32 elements - while (true) + for (; memberMapBlock != 0 && memberMapIndex < allMemberMaps.Count; memberMapIndex++, memberMapBlock >>= 1) { - // examine missing elements (memberMapBlock is shifted right as we work through the block) - for (; (memberMapBlock & 1) != 0; ++memberMapIndex, memberMapBlock >>= 1) - { - var memberMap = allMemberMaps[memberMapIndex]; - if (memberMap.IsReadOnly) - { - continue; - } - - if (memberMap.IsRequired) - { - var fieldOrProperty = (memberMap.MemberInfo is FieldInfo) ? "field" : "property"; - var message = string.Format( - "Required element '{0}' for {1} '{2}' of class {3} is missing.", - memberMap.ElementName, fieldOrProperty, memberMap.MemberName, _classMap.ClassType.FullName); - throw new FormatException(message); - } + if ((memberMapBlock & 1) == 0) + continue; - if (document != null) - { - memberMap.ApplyDefaultValue(document); - } - else if (memberMap.IsDefaultValueSpecified && !memberMap.IsReadOnly) - { - values[memberMap.ElementName] = memberMap.DefaultValue; - } + var memberMap = allMemberMaps[memberMapIndex]; + if (memberMap.IsReadOnly) + { + continue; } - if (memberMapBlock == 0) + if (memberMap.IsRequired) { - break; + var fieldOrProperty = (memberMap.MemberInfo is FieldInfo) ? "field" : "property"; + throw new FormatException($"Required element '{memberMap.ElementName}' for {fieldOrProperty} '{memberMap.MemberName}' of class {_classMap.ClassType.FullName} is missing."); } - // skip ahead to the next missing element - var leastSignificantBit = FastMemberMapHelper.GetLeastSignificantBit(memberMapBlock); - memberMapIndex += leastSignificantBit; - memberMapBlock >>= leastSignificantBit; + if (document != null) + { + memberMap.ApplyDefaultValue(document); + } + else if (memberMap.IsDefaultValueSpecified && !memberMap.IsReadOnly) + { + values[memberMap.ElementName] = memberMap.DefaultValue; + } } } @@ -335,13 +327,11 @@ public bool GetDocumentId( idGenerator = idMemberMap.IdGenerator; return true; } - else - { - id = null; - idNominalType = null; - idGenerator = null; - return false; - } + + id = null; + idNominalType = null; + idGenerator = null; + return false; } /// @@ -694,48 +684,63 @@ private bool ShouldSerializeDiscriminator(Type nominalType) // nested classes // helper class that implements member map bit array helper functions - private static class FastMemberMapHelper + internal static class FastMemberMapHelper { - public static uint[] GetBitArray(int memberCount) + internal ref struct MembersBitArray() { - var bitArrayOffset = memberCount & 31; - var bitArrayLength = memberCount >> 5; - if (bitArrayOffset == 0) - { - return new uint[bitArrayLength]; - } - var bitArray = new uint[bitArrayLength + 1]; - bitArray[bitArrayLength] = ~0U << bitArrayOffset; // set unused bits to 1 - return bitArray; - } + private readonly ArrayPool _arrayPool; + private readonly Span _bitArray; + private readonly uint[] _rentedBuffer; + private bool _isDisposed = false; - // see http://graphics.stanford.edu/~seander/bithacks.html#ZerosOnRightBinSearch - // also returns 31 if no bits are set; caller must check this case - public static int GetLeastSignificantBit(uint bitBlock) - { - var leastSignificantBit = 1; - if ((bitBlock & 65535) == 0) - { - bitBlock >>= 16; - leastSignificantBit |= 16; - } - if ((bitBlock & 255) == 0) + public MembersBitArray(Span bitArray) : this() { - bitBlock >>= 8; - leastSignificantBit |= 8; + _arrayPool = null; + _bitArray = bitArray; + _rentedBuffer = null; + + _bitArray.Clear(); } - if ((bitBlock & 15) == 0) + + public MembersBitArray(int lengthInUInts, ArrayPool arrayPool) : this() { - bitBlock >>= 4; - leastSignificantBit |= 4; + _arrayPool = arrayPool; + _rentedBuffer = arrayPool.Rent(lengthInUInts); + _bitArray = _rentedBuffer.AsSpan(0, lengthInUInts); + + _bitArray.Clear(); } - if ((bitBlock & 3) == 0) + + public Span Span => _bitArray; + public ArrayPool ArrayPool => _arrayPool; + + public void SetMemberIndex(int memberMapIndex) => + _bitArray[memberMapIndex >> 5] |= 1U << (memberMapIndex & 31); + + public void Dispose() { - bitBlock >>= 2; - leastSignificantBit |= 2; + if (_isDisposed) + return; + + if (_rentedBuffer != null) + { + _arrayPool.Return(_rentedBuffer); + } + _isDisposed = true; } - return leastSignificantBit - (int)(bitBlock & 1); } + + public static (int LengthInUInts, bool UseStackAlloc) GetLengthInUInts(int membersCount) + { + var lengthInUInts = (membersCount + 31) >> 5; + return (lengthInUInts, lengthInUInts <= 8); // Use stackalloc for up to 256 members + } + + public static MembersBitArray GetMembersBitArray(Span span) => + new(span); + + public static MembersBitArray GetMembersBitArray(int lengthInUInts) => + new(lengthInUInts, ArrayPool.Shared); } } } diff --git a/tests/MongoDB.Bson.Tests/Serialization/BsonClassMapSerializerTests.cs b/tests/MongoDB.Bson.Tests/Serialization/BsonClassMapSerializerTests.cs index 649c45e64c0..41e1a226580 100644 --- a/tests/MongoDB.Bson.Tests/Serialization/BsonClassMapSerializerTests.cs +++ b/tests/MongoDB.Bson.Tests/Serialization/BsonClassMapSerializerTests.cs @@ -13,10 +13,18 @@ * limitations under the License. */ +using System; +using System.Buffers; +using System.Linq; +using System.Reflection; +using System.Reflection.Emit; using FluentAssertions; using MongoDB.Bson.IO; using MongoDB.Bson.Serialization; using MongoDB.Bson.Serialization.Serializers; +using MongoDB.Bson.TestHelpers; +using MongoDB.TestHelpers.XunitExtensions; +using Moq; using Xunit; namespace MongoDB.Bson.Tests.Serialization @@ -39,6 +47,62 @@ static BsonClassMapSerializerTests() } // public methods + [Theory] + [ParameterAttributeData] + public void Deserialize_should_not_throw_when_all_required_elements_present( + [Values(0, 1, 8, 23, 63, 111, 127, 128, 129, 555, 1024, 2500)]int membersCount) + { + var subject = BuildTypeAndGetSerializer("Prop", membersCount); + var properties = Enumerable + .Range(0, membersCount) + .Select(i => $"\"Prop_{i}\" : \"Value_{i}\""); + var json = $"{{{string.Join(",", properties)}}}"; + + using var reader = new JsonReader(json); + var context = BsonDeserializationContext.CreateRoot(reader); + + var obj = subject.Deserialize(context); + + for (var i = 0; i < membersCount; i++) + { + Reflector.GetFieldValue(obj, $"Prop_{i}", BindingFlags.Public | BindingFlags.Instance) + .Should().Be($"Value_{i}"); + } + } + + [Theory] + [InlineData(1, 0)] + [InlineData(8, 0)] + [InlineData(8, 7)] + [InlineData(256, 1)] + [InlineData(256, 255)] + [InlineData(555, 333)] + [InlineData(555, 551)] + [InlineData(555, 554)] + [InlineData(1024, 0)] + [InlineData(1024, 555)] + [InlineData(1024, 992)] + [InlineData(1024, 993)] + [InlineData(1024, 1000)] + [InlineData(1024, 1023)] + public void Deserialize_should_throw_FormatException_when_required_element_is_not_found(int membersCount, int missingMemberIndex) + { + var subject = BuildTypeAndGetSerializer("Prop", membersCount); + var properties = Enumerable + .Range(0, membersCount) + .Except([missingMemberIndex]) + .Select(i => $"\"Prop_{i}\" : \"Value_{i}\""); + var json = $"{{{string.Join(",", properties)}}}"; + + using var reader = new JsonReader(json); + var context = BsonDeserializationContext.CreateRoot(reader); + + var exception = Record.Exception(() => subject.Deserialize(context)); + exception.Should() + .BeOfType() + .Subject.Message.Should().Contain($"Prop_{missingMemberIndex}"); + } + [Fact] public void Deserialize_should_throw_invalidOperationException_when_creator_returns_null() { @@ -125,16 +189,143 @@ public void Equals_with_not_equal_field_should_return_false() result.Should().Be(false); } + [Theory] + [InlineData(0, 0, true)] + [InlineData(1, 1, true)] + [InlineData(2, 1, true)] + [InlineData(32, 1, true)] + [InlineData(33, 2, true)] + [InlineData(256, 8, true)] + [InlineData(257, 9, false)] + public void FastMemberMapHelper_GetMembersBitArrayLength_should_return_correctValue(int memberCount, int expectedLengthInUInts, bool expectedUseStackAlloc) + { + var (lengthInUInts, useStackAlloc) = BsonClassMapSerializer.FastMemberMapHelper.GetLengthInUInts(memberCount); + + lengthInUInts.ShouldBeEquivalentTo(expectedLengthInUInts); + useStackAlloc.ShouldBeEquivalentTo(expectedUseStackAlloc); + } + + [Fact] + public void FastMemberMapHelper_GetMembersBitArray_with_span_should_use_the_provided_span() + { + var backingArray = new uint[] { 1, 2, 3 }; + using var bitArray = BsonClassMapSerializer.FastMemberMapHelper.GetMembersBitArray(backingArray); + + bitArray.Span.ToArray().ShouldBeEquivalentTo(new uint[] { 0, 0, 0 }); + bitArray.ArrayPool.Should().Be(null); + + bitArray.Span[0] = 12; + backingArray.ShouldBeEquivalentTo(new uint[] { 12, 0, 0 }); + } + + [Theory] + [InlineData(3)] + [InlineData(25)] + public void FastMemberMapHelper_GetMembersBitArray_with_length_should_allocate_span(int length) + { + using var bitArray = BsonClassMapSerializer.FastMemberMapHelper.GetMembersBitArray(length); + + bitArray.Span.ToArray().ShouldBeEquivalentTo(Enumerable.Repeat(0, length)); + bitArray.ArrayPool.Should().Be(ArrayPool.Shared); + } + + [Theory] + [InlineData(1)] + [InlineData(2)] + public void FastMemberMapHelper_MembersBitArray_with_arraypool_should_dispose_only_once(int disposeCount) + { + var backingArray = new uint[] { 1, 2, 3 }; + + var mockArrayPool = new Mock>(); + mockArrayPool.Setup(p => p.Rent(backingArray.Length)).Returns(backingArray); + var bitArray = new BsonClassMapSerializer.FastMemberMapHelper.MembersBitArray(backingArray.Length, mockArrayPool.Object); + + for (int i = 0; i < disposeCount; i++) + { + bitArray.Dispose(); + } + + mockArrayPool.Verify(a => a.Return(backingArray, false), Times.Once); + } + + [Theory] + [InlineData(1, 0)] + [InlineData(8, 0)] + [InlineData(8, 7)] + [InlineData(99, 100)] + [InlineData(266, 255)] + [InlineData(544, 0)] + [InlineData(621, 255)] + public void FastMemberMapHelper_GetMembersBitArray_SetMemberIndex_should_set_correct_bit(int membersCount, int memberIndex) + { + var (length, _) = BsonClassMapSerializer.FastMemberMapHelper.GetLengthInUInts(membersCount); + using var bitArray = BsonClassMapSerializer.FastMemberMapHelper.GetMembersBitArray(length); + + var span = bitArray.Span; + var blockIndex = memberIndex >> 5; + var bitIndex = memberIndex & 31; + + bitArray.SetMemberIndex(memberIndex); + + for (var i = 0; i < span.Length; i++) + { + for (int b = 0; b < 32; b++) + { + var bit = span[i] & (1U << b); + + if (i == blockIndex && b == bitIndex) + { + bit.Should().Be(1U << b); + } + else + { + bit.Should().Be(0); + } + } + } + } + [Fact] public void GetHashCode_should_return_zero() { var x = new BsonClassMapSerializer(__classMap1); var result = x.GetHashCode(); - result.Should().Be(0); } + private IBsonSerializer BuildTypeAndGetSerializer(string propertyNamePrefix, int propertiesCount) + { + var assemblyName = new AssemblyName("DynamicAssembly"); + var assemblyBuilder = AssemblyBuilder.DefineDynamicAssembly(assemblyName, AssemblyBuilderAccess.Run); + var moduleBuilder = assemblyBuilder.DefineDynamicModule("DynamicModule"); + + var typeBuilder = moduleBuilder.DefineType($"MyDynamicClass_{propertiesCount}", TypeAttributes.Public | TypeAttributes.Class); + + for (var i = 0; i < propertiesCount; i++) + { + _ = typeBuilder.DefineField($"{propertyNamePrefix}_{i}", + typeof(string), + FieldAttributes.Public); + } + + var newType = typeBuilder.CreateType(); + + var classMap = new BsonClassMap(newType); + for (var i = 0; i < propertiesCount; i++) + { + classMap + .MapField($"Prop_{i}") + .SetIsRequired(true); + } + classMap.Freeze(); + + var classMapSerializerType = typeof(BsonClassMapSerializer<>).MakeGenericType(newType); + var classMapSerializer = (IBsonSerializer)Activator.CreateInstance(classMapSerializerType, classMap); + + return classMapSerializer; + } + // nested classes private class MyModel {