diff --git a/src/Microsoft.AspNet.Mvc.Core/Formatters/JsonContractResolver.cs b/src/Microsoft.AspNet.Mvc.Core/Formatters/JsonContractResolver.cs index 2be2872abc..775938d478 100644 --- a/src/Microsoft.AspNet.Mvc.Core/Formatters/JsonContractResolver.cs +++ b/src/Microsoft.AspNet.Mvc.Core/Formatters/JsonContractResolver.cs @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using System; using System.ComponentModel.DataAnnotations; using System.Reflection; using Newtonsoft.Json; diff --git a/src/Microsoft.AspNet.Mvc.Core/Properties/Resources.Designer.cs b/src/Microsoft.AspNet.Mvc.Core/Properties/Resources.Designer.cs index 62dc4be9f4..d6212aa949 100644 --- a/src/Microsoft.AspNet.Mvc.Core/Properties/Resources.Designer.cs +++ b/src/Microsoft.AspNet.Mvc.Core/Properties/Resources.Designer.cs @@ -1738,6 +1738,54 @@ internal static string FormatModelType_WrongType(object p0, object p1) return string.Format(CultureInfo.CurrentCulture, GetString("ModelType_WrongType"), p0, p1); } + /// + /// The '{0}' cannot serialize an object of type '{1}' to session state. + /// + internal static string TempData_CannotSerializeToSession + { + get { return GetString("TempData_CannotSerializeToSession"); } + } + + /// + /// The '{0}' cannot serialize an object of type '{1}' to session state. + /// + internal static string FormatTempData_CannotSerializeToSession(object p0, object p1) + { + return string.Format(CultureInfo.CurrentCulture, GetString("TempData_CannotSerializeToSession"), p0, p1); + } + + /// + /// Cannot deserialize {0} of type '{1}'. + /// + internal static string TempData_CannotDeserializeToken + { + get { return GetString("TempData_CannotDeserializeToken"); } + } + + /// + /// Cannot deserialize {0} of type '{1}'. + /// + internal static string FormatTempData_CannotDeserializeToken(object p0, object p1) + { + return string.Format(CultureInfo.CurrentCulture, GetString("TempData_CannotDeserializeToken"), p0, p1); + } + + /// + /// The '{0}' cannot serialize a dictionary with a key of type '{1}' to session state. + /// + internal static string TempData_CannotSerializeDictionary + { + get { return GetString("TempData_CannotSerializeDictionary"); } + } + + /// + /// The '{0}' cannot serialize a dictionary with a key of type '{1}' to session state. + /// + internal static string FormatTempData_CannotSerializeDictionary(object p0, object p1) + { + return string.Format(CultureInfo.CurrentCulture, GetString("TempData_CannotSerializeDictionary"), p0, p1); + } + private static string GetString(string name, params string[] formatterNames) { var value = _resourceManager.GetString(name); diff --git a/src/Microsoft.AspNet.Mvc.Core/Resources.resx b/src/Microsoft.AspNet.Mvc.Core/Resources.resx index 2ad230d2c9..7db062ae64 100644 --- a/src/Microsoft.AspNet.Mvc.Core/Resources.resx +++ b/src/Microsoft.AspNet.Mvc.Core/Resources.resx @@ -451,4 +451,13 @@ The model's runtime type '{0}' is not assignable to the type '{1}'. + + The '{0}' cannot serialize an object of type '{1}' to session state. + + + Cannot deserialize {0} of type '{1}'. + + + The '{0}' cannot serialize a dictionary with a key of type '{1}' to session state. + \ No newline at end of file diff --git a/src/Microsoft.AspNet.Mvc.Core/SessionStateTempDataProvider.cs b/src/Microsoft.AspNet.Mvc.Core/SessionStateTempDataProvider.cs index 7152a733cb..6541843eef 100644 --- a/src/Microsoft.AspNet.Mvc.Core/SessionStateTempDataProvider.cs +++ b/src/Microsoft.AspNet.Mvc.Core/SessionStateTempDataProvider.cs @@ -2,12 +2,18 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections.Concurrent; using System.Collections.Generic; +using System.Diagnostics; using System.IO; +using System.Linq; +using System.Reflection; using Microsoft.AspNet.Http; +using Microsoft.AspNet.Mvc.Core; using Microsoft.Framework.Internal; using Newtonsoft.Json; using Newtonsoft.Json.Bson; +using Newtonsoft.Json.Linq; namespace Microsoft.AspNet.Mvc { @@ -16,8 +22,34 @@ namespace Microsoft.AspNet.Mvc /// public class SessionStateTempDataProvider : ITempDataProvider { - private static JsonSerializer jsonSerializer = new JsonSerializer(); - private static string TempDataSessionStateKey = "__ControllerTempData"; + private const string TempDataSessionStateKey = "__ControllerTempData"; + private readonly JsonSerializer _jsonSerializer = JsonSerializer.Create( + new JsonSerializerSettings() + { + TypeNameHandling = TypeNameHandling.None + }); + + private static readonly MethodInfo _convertArrayMethodInfo = typeof(SessionStateTempDataProvider).GetMethod( + nameof(ConvertArray), BindingFlags.Static | BindingFlags.NonPublic); + private static readonly MethodInfo _convertDictMethodInfo = typeof(SessionStateTempDataProvider).GetMethod( + nameof(ConvertDictionary), BindingFlags.Static | BindingFlags.NonPublic); + + private static readonly ConcurrentDictionary> _arrayConverters = + new ConcurrentDictionary>(); + private static readonly ConcurrentDictionary> _dictionaryConverters = + new ConcurrentDictionary>(); + + private static readonly Dictionary _tokenTypeLookup = new Dictionary + { + { JTokenType.String, typeof(string) }, + { JTokenType.Integer, typeof(int) }, + { JTokenType.Boolean, typeof(bool) }, + { JTokenType.Float, typeof(float) }, + { JTokenType.Guid, typeof(Guid) }, + { JTokenType.Date, typeof(DateTime) }, + { JTokenType.TimeSpan, typeof(TimeSpan) }, + { JTokenType.Uri, typeof(Uri) }, + }; /// public virtual IDictionary LoadTempData([NotNull] HttpContext context) @@ -37,9 +69,68 @@ public virtual IDictionary LoadTempData([NotNull] HttpContext co using (var memoryStream = new MemoryStream(value)) using (var writer = new BsonReader(memoryStream)) { - tempDataDictionary = jsonSerializer.Deserialize>(writer); + tempDataDictionary = _jsonSerializer.Deserialize>(writer); } + var convertedDictionary = new Dictionary(tempDataDictionary, StringComparer.OrdinalIgnoreCase); + foreach (var item in tempDataDictionary) + { + var jArrayValue = item.Value as JArray; + if (jArrayValue != null && jArrayValue.Count > 0) + { + var arrayType = jArrayValue[0].Type; + Type returnType; + if (_tokenTypeLookup.TryGetValue(arrayType, out returnType)) + { + var arrayConverter = _arrayConverters.GetOrAdd(returnType, type => + { + return (Func)_convertArrayMethodInfo.MakeGenericMethod(type).CreateDelegate(typeof(Func)); + }); + var result = arrayConverter(jArrayValue); + + convertedDictionary[item.Key] = result; + } + else + { + var message = Resources.FormatTempData_CannotDeserializeToken(nameof(JToken), arrayType); + throw new InvalidOperationException(message); + } + } + else + { + var jObjectValue = item.Value as JObject; + if (jObjectValue == null) + { + continue; + } + else if (!jObjectValue.HasValues) + { + convertedDictionary[item.Key] = null; + continue; + } + + var jTokenType = jObjectValue.Properties().First().Value.Type; + Type valueType; + if (_tokenTypeLookup.TryGetValue(jTokenType, out valueType)) + { + var dictionaryConverter = _dictionaryConverters.GetOrAdd(valueType, type => + { + return (Func)_convertDictMethodInfo.MakeGenericMethod(type).CreateDelegate(typeof(Func)); + }); + var result = dictionaryConverter(jObjectValue); + + convertedDictionary[item.Key] = result; + } + else + { + var message = Resources.FormatTempData_CannotDeserializeToken(nameof(JToken), jTokenType); + throw new InvalidOperationException(message); + } + } + } + + tempDataDictionary = convertedDictionary; + // If we got it from Session, remove it so that no other request gets it session.Remove(TempDataSessionStateKey); } @@ -59,13 +150,19 @@ public virtual void SaveTempData([NotNull] HttpContext context, IDictionary 0); if (hasValues) { + foreach (var item in values.Values) + { + // We want to allow only simple types to be serialized in session. + EnsureObjectCanBeSerialized(item); + } + // Accessing Session property will throw if the session middleware is not enabled. var session = context.Session; using (var memoryStream = new MemoryStream()) using (var writer = new BsonWriter(memoryStream)) { - jsonSerializer.Serialize(writer, values); + _jsonSerializer.Serialize(writer, values); session[TempDataSessionStateKey] = memoryStream.ToArray(); } } @@ -80,5 +177,65 @@ private static bool IsSessionEnabled(HttpContext context) { return context.GetFeature() != null; } + + internal void EnsureObjectCanBeSerialized(object item) + { + var itemType = item.GetType(); + Type actualType = null; + + if (itemType.IsArray) + { + itemType = itemType.GetElementType(); + } + else if (itemType.GetTypeInfo().IsGenericType) + { + if (itemType.ExtractGenericInterface(typeof(IList<>)) != null) + { + var genericTypeArguments = itemType.GetGenericArguments(); + Debug.Assert(genericTypeArguments.Length == 1, "IList has one generic argument"); + actualType = genericTypeArguments[0]; + } + else if (itemType.ExtractGenericInterface(typeof(IDictionary<,>)) != null) + { + var genericTypeArguments = itemType.GetGenericArguments(); + Debug.Assert(genericTypeArguments.Length == 2, "IDictionary has two generic arguments"); + // Throw if the key type of the dictionary is not string. + if (genericTypeArguments[0] != typeof(string)) + { + var message = Resources.FormatTempData_CannotSerializeDictionary( + typeof(SessionStateTempDataProvider).FullName, genericTypeArguments[0]); + throw new InvalidOperationException(message); + } + else + { + actualType = genericTypeArguments[1]; + } + } + } + + actualType = actualType ?? itemType; + if (!TypeHelper.IsSimpleType(actualType)) + { + var underlyingType = Nullable.GetUnderlyingType(actualType) ?? actualType; + var message = Resources.FormatTempData_CannotSerializeToSession( + typeof(SessionStateTempDataProvider).FullName, underlyingType); + throw new InvalidOperationException(message); + } + } + + private static IList ConvertArray(JArray array) + { + return array.Values().ToArray(); + } + + private static IDictionary ConvertDictionary(JObject jObject) + { + var convertedDictionary = new Dictionary(StringComparer.Ordinal); + foreach (var item in jObject) + { + convertedDictionary.Add(item.Key, jObject.Value(item.Key)); + } + return convertedDictionary; + } } } \ No newline at end of file diff --git a/src/Microsoft.AspNet.Mvc/MvcServices.cs b/src/Microsoft.AspNet.Mvc/MvcServices.cs index 381f81bb97..188463de53 100644 --- a/src/Microsoft.AspNet.Mvc/MvcServices.cs +++ b/src/Microsoft.AspNet.Mvc/MvcServices.cs @@ -166,8 +166,9 @@ public static IServiceCollection GetDefaultServices() services.AddTransient(); // Temp Data - services.AddSingleton(); services.AddScoped(); + // This does caching so it should stay singleton + services.AddSingleton(); return services; } diff --git a/test/Microsoft.AspNet.Mvc.Core.Test/SessionStateTempDataProviderTest.cs b/test/Microsoft.AspNet.Mvc.Core.Test/SessionStateTempDataProviderTest.cs index 1c266163d7..6291afe2f9 100644 --- a/test/Microsoft.AspNet.Mvc.Core.Test/SessionStateTempDataProviderTest.cs +++ b/test/Microsoft.AspNet.Mvc.Core.Test/SessionStateTempDataProviderTest.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections; using System.Collections.Generic; using Microsoft.AspNet.Http; using Moq; @@ -76,6 +77,151 @@ public void Save_NullSession_NonEmptyDictionary_Throws() }); } + public static TheoryData InvalidTypes + { + get + { + return new TheoryData + { + { new object(), typeof(object) }, + { new object[3], typeof(object) }, + { new TestItem(), typeof(TestItem) }, + { new List(), typeof(TestItem) }, + { new Dictionary(), typeof(TestItem) }, + }; + } + } + + [Theory] + [MemberData(nameof(InvalidTypes))] + public void EnsureObjectCanBeSerialized_InvalidType_Throws(object value, Type type) + { + // Arrange + var testProvider = new SessionStateTempDataProvider(); + + // Act & Assert + var exception = Assert.Throws(() => + { + testProvider.EnsureObjectCanBeSerialized(value); + }); + Assert.Equal($"The '{typeof(SessionStateTempDataProvider).FullName}' cannot serialize an object of type '{type}' to session state.", + exception.Message); + } + + public static TheoryData InvalidDictionaryTypes + { + get + { + return new TheoryData + { + { new Dictionary(), typeof(int) }, + { new Dictionary(), typeof(Uri) }, + { new Dictionary(), typeof(object) }, + { new Dictionary(), typeof(TestItem) } + }; + } + } + + [Theory] + [MemberData(nameof(InvalidDictionaryTypes))] + public void EnsureObjectCanBeSerialized_InvalidDictionaryType_Throws(object value, Type type) + { + // Arrange + var testProvider = new SessionStateTempDataProvider(); + + // Act & Assert + var exception = Assert.Throws(() => + { + testProvider.EnsureObjectCanBeSerialized(value); + }); + Assert.Equal($"The '{typeof(SessionStateTempDataProvider).FullName}' cannot serialize a dictionary with a key of type '{type}' to session state.", + exception.Message); + } + + public static TheoryData ValidTypes + { + get + { + return new TheoryData + { + { 10 }, + { new int[]{ 10, 20 } }, + { "FooValue" }, + { new Uri("http://Foo") }, + { Guid.NewGuid() }, + { new List { "foo", "bar" } }, + { new DateTimeOffset() }, + { 100.1m }, + { new Dictionary() }, + { new Uri[] { new Uri("http://Foo"), new Uri("http://Bar") } } + }; + } + } + + [Theory] + [MemberData(nameof(ValidTypes))] + public void EnsureObjectCanBeSerialized_ValidType_DoesNotThrow(object value) + { + // Arrange + var testProvider = new SessionStateTempDataProvider(); + + // Act & Assert (Does not throw) + testProvider.EnsureObjectCanBeSerialized(value); + } + + [Fact] + public void SaveAndLoad_SimpleTypesCanBeStoredAndLoaded() + { + // Arrange + var testProvider = new SessionStateTempDataProvider(); + var inputGuid = Guid.NewGuid(); + var inputDictionary = new Dictionary + { + { "Hello", "World" }, + }; + var input = new Dictionary + { + { "string", "value" }, + { "int", 10 }, + { "bool", false }, + { "DateTime", new DateTime() }, + { "Guid", inputGuid }, + { "List`string", new List { "one", "two" } }, + { "Dictionary", inputDictionary }, + { "EmptyDictionary", new Dictionary() } + }; + var context = GetHttpContext(new TestSessionCollection(), true); + + // Act + testProvider.SaveTempData(context, input); + var TempData = testProvider.LoadTempData(context); + + // Assert + var stringVal = Assert.IsType(TempData["string"]); + Assert.Equal("value", stringVal); + var intVal = Convert.ToInt32(TempData["int"]); + Assert.Equal(10, intVal); + var boolVal = Assert.IsType(TempData["bool"]); + Assert.Equal(false, boolVal); + var datetimeVal = Assert.IsType(TempData["DateTime"]); + Assert.Equal(new DateTime().ToString(), datetimeVal.ToString()); + var guidVal = Assert.IsType(TempData["Guid"]); + Assert.Equal(inputGuid.ToString(), guidVal.ToString()); + var list = (IList)TempData["List`string"]; + Assert.Equal(2, list.Count); + Assert.Equal("one", list[0]); + Assert.Equal("two", list[1]); + var dictionary = Assert.IsType>(TempData["Dictionary"]); + Assert.Equal("World", dictionary["Hello"]); + var emptyDictionary = (IDictionary)TempData["EmptyDictionary"]; + Assert.Null(emptyDictionary); + } + + private class TestItem + { + public int DummyInt { get; set; } + } + private HttpContext GetHttpContext(ISessionCollection session, bool sessionEnabled=true) { var httpContext = new Mock(); @@ -87,12 +233,63 @@ private HttpContext GetHttpContext(ISessionCollection session, bool sessionEnabl { httpContext.Setup(h => h.Session).Throws(); } + else + { + httpContext.Setup(h => h.Session[It.IsAny()]); + } if (sessionEnabled) { httpContext.Setup(h => h.GetFeature()).Returns(Mock.Of()); - httpContext.Setup(h => h.Session[It.IsAny()]); } return httpContext.Object; } + + private class TestSessionCollection : ISessionCollection + { + private Dictionary _innerDictionary = new Dictionary(); + + public byte[] this[string key] + { + get + { + return _innerDictionary[key]; + } + + set + { + _innerDictionary[key] = value; + } + } + + public void Clear() + { + _innerDictionary.Clear(); + } + + public IEnumerator> GetEnumerator() + { + return _innerDictionary.GetEnumerator(); + } + + public void Remove(string key) + { + _innerDictionary.Remove(key); + } + + public void Set(string key, ArraySegment value) + { + _innerDictionary[key] = value.AsArray(); + } + + public bool TryGetValue(string key, out byte[] value) + { + return _innerDictionary.TryGetValue(key, out value); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return _innerDictionary.GetEnumerator(); + } + } } } \ No newline at end of file diff --git a/test/Microsoft.AspNet.Mvc.FunctionalTests/TempDataTest.cs b/test/Microsoft.AspNet.Mvc.FunctionalTests/TempDataTest.cs index 39a625ef1f..0e50521f57 100644 --- a/test/Microsoft.AspNet.Mvc.FunctionalTests/TempDataTest.cs +++ b/test/Microsoft.AspNet.Mvc.FunctionalTests/TempDataTest.cs @@ -142,6 +142,61 @@ public async Task Peek_RetainsTempData() Assert.Equal("Foo", body); } + [Fact] + public async Task TempData_ValidTypes_RoundTripProperly() + { + // Arrange + var server = TestHelper.CreateServer(_app, SiteName, _configureServices); + var client = server.CreateClient(); + var testGuid = Guid.NewGuid(); + var nameValueCollection = new List> + { + new KeyValuePair("value", "Foo"), + new KeyValuePair("intValue", "10"), + new KeyValuePair("listValues", "Foo1"), + new KeyValuePair("listValues", "Foo2"), + new KeyValuePair("listValues", "Foo3"), + new KeyValuePair("datetimeValue", "10/10/2010"), + new KeyValuePair("guidValue", testGuid.ToString()), + }; + var content = new FormUrlEncodedContent(nameValueCollection); + + // Act 1 + var redirectResponse = await client.PostAsync("/Home/SetTempDataMultiple", content); + + // Assert 1 + Assert.Equal(HttpStatusCode.Redirect, redirectResponse.StatusCode); + + // Act 2 + var response = await client.SendAsync(GetRequest(redirectResponse.Headers.Location.ToString(), redirectResponse)); + + // Assert 2 + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + var body = await response.Content.ReadAsStringAsync(); + Assert.Equal($"Foo 10 3 10/10/2010 00:00:00 {testGuid.ToString()}", body); + } + + [Fact] + public async Task TempData_InvalidType_Throws() + { + // Arrange + var server = TestHelper.CreateServer(_app, SiteName, _configureServices); + var client = server.CreateClient(); + var nameValueCollection = new List> + { + new KeyValuePair("value", "Foo"), + }; + var content = new FormUrlEncodedContent(nameValueCollection); + + // Act & Assert + var exception = await Assert.ThrowsAsync(async () => + { + await client.PostAsync("/Home/SetTempDataInvalidType", content); + }); + Assert.Equal("The '" + typeof(SessionStateTempDataProvider).FullName + "' cannot serialize an object of type '" + + typeof(TempDataWebSite.Controllers.HomeController.NonSerializableType).FullName + "' to session state.", exception.Message); + } + private HttpRequestMessage GetRequest(string path, HttpResponseMessage response) { var request = new HttpRequestMessage(HttpMethod.Get, path); diff --git a/test/WebSites/TempDataWebSite/Controllers/HomeController.cs b/test/WebSites/TempDataWebSite/Controllers/HomeController.cs index 732a8c269d..ae7f663267 100644 --- a/test/WebSites/TempDataWebSite/Controllers/HomeController.cs +++ b/test/WebSites/TempDataWebSite/Controllers/HomeController.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; +using System.Collections.Generic; using Microsoft.AspNet.Mvc; namespace TempDataWebSite.Controllers @@ -41,5 +43,49 @@ public IActionResult PeekTempData() var peekValue = TempData.Peek("key"); return Content(peekValue.ToString()); } + + public IActionResult SetTempDataMultiple( + string value, + int intValue, + IList listValues, + DateTime datetimeValue, + Guid guidValue) + { + TempData["key1"] = value; + TempData["key2"] = intValue; + TempData["key3"] = listValues; + TempData["key4"] = datetimeValue; + TempData["key5"] = guidValue; + return RedirectToAction("GetTempDataMultiple"); + } + + public string GetTempDataMultiple() + { + var value1 = TempData["key1"].ToString(); + var value2 = Convert.ToInt32(TempData["key2"]); + var value3 = (IList)TempData["key3"]; + var value4 = (DateTime)TempData["key4"]; + var value5 = (Guid)TempData["key5"]; + return $"{value1} {value2.ToString()} {value3.Count.ToString()} {value4.ToString()} {value5.ToString()}"; + } + + public string SetTempDataInvalidType() + { + var exception = ""; + try + { + TempData["key"] = new NonSerializableType(); + } + catch (Exception e) + { + exception = e.Message; + } + + return exception; + } + + public class NonSerializableType + { + } } }