diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ElementsMarshalling.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ElementsMarshalling.cs index 5fcd07703ccb09..6e1b1d44ba616c 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ElementsMarshalling.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ElementsMarshalling.cs @@ -468,13 +468,13 @@ public override StatementSyntax GenerateUnmanagedToManagedByValueOutMarshalState private List GenerateElementStages( StubIdentifierContext context, IBoundMarshallingGenerator elementMarshaller, - out LinearCollectionElementIdentifierContext elementSetupSubContext, + out string indexer, params StubIdentifierContext.Stage[] stagesToGeneratePerElement) { string managedSpanIdentifier = MarshallerHelpers.GetManagedSpanIdentifier(CollectionSource.TypeInfo, context); string nativeSpanIdentifier = MarshallerHelpers.GetNativeSpanIdentifier(CollectionSource.TypeInfo, context); StubCodeContext elementCodeContext = StubCodeContext.CreateElementMarshallingContext(CollectionSource.CodeContext); - elementSetupSubContext = new LinearCollectionElementIdentifierContext( + LinearCollectionElementIdentifierContext elementSetupSubContext = new( context, elementMarshaller.TypeInfo, managedSpanIdentifier, @@ -485,13 +485,74 @@ private List GenerateElementStages( CodeEmitOptions = context.CodeEmitOptions }; + indexer = elementSetupSubContext.IndexerIdentifier; + + StubIdentifierContext identifierContext = elementSetupSubContext; + + if (elementMarshaller.NativeType is PointerTypeInfo) + { + identifierContext = new GenericFriendlyPointerIdentifierContext(elementSetupSubContext, elementMarshaller.TypeInfo, $"{nativeSpanIdentifier}__{indexer}") + { + CodeEmitOptions = elementSetupSubContext.CodeEmitOptions, + }; + } + List elementStatements = []; foreach (StubIdentifierContext.Stage stage in stagesToGeneratePerElement) { - var elementSubContext = elementSetupSubContext with { CurrentStage = stage }; - elementStatements.AddRange(elementMarshaller.Generate(elementSubContext)); + var elementIdentifierContext = identifierContext with { CurrentStage = stage }; + elementStatements.AddRange(elementMarshaller.Generate(elementIdentifierContext)); + } + + if (elementStatements.Count == 0) + { + return []; + } + + // Only add the setup stage if we generated code for other stages. + elementStatements.InsertRange(0, elementMarshaller.Generate(identifierContext with { CurrentStage = StubIdentifierContext.Stage.Setup })); + + if (identifierContext is not GenericFriendlyPointerIdentifierContext) + { + // If we didn't need to account for pointer types, we have the statements we need. + return elementStatements; } - return elementStatements; + + // If we have the generic friendly pointer context, we need to declare the special identifier and assign to/from it. + + // = ()[i]; + StatementSyntax exactTypeDeclaration = + LocalDeclarationStatement( + VariableDeclaration( + elementMarshaller.NativeType.Syntax, + SingletonSeparatedList( + VariableDeclarator( + Identifier(identifierContext.GetIdentifiers(elementMarshaller.TypeInfo).native)) + .WithInitializer( + EqualsValueClause( + CastExpression(elementMarshaller.NativeType.Syntax, + ParseExpression(elementSetupSubContext.GetIdentifiers(elementMarshaller.TypeInfo).native))))))); + + if (stagesToGeneratePerElement.Any(stage => stage is StubIdentifierContext.Stage.Marshal or StubIdentifierContext.Stage.PinnedMarshal)) + { + // [i] = (); + StatementSyntax propagateResult = AssignmentStatement( + ParseExpression(elementSetupSubContext.GetIdentifiers(elementMarshaller.TypeInfo).native), + CastExpression(TypeSyntaxes.System_IntPtr, + IdentifierName(identifierContext.GetIdentifiers(elementMarshaller.TypeInfo).native))); + + return + [ + exactTypeDeclaration, + ..elementStatements, + propagateResult + ]; + } + + return [ + exactTypeDeclaration, + ..elementStatements + ]; } private StatementSyntax GenerateContentsMarshallingStatement( @@ -500,22 +561,14 @@ private StatementSyntax GenerateContentsMarshallingStatement( IBoundMarshallingGenerator elementMarshaller, params StubIdentifierContext.Stage[] stagesToGeneratePerElement) { - var elementStatements = GenerateElementStages(context, elementMarshaller, out var elementSetupSubContext, stagesToGeneratePerElement); + var elementStatements = GenerateElementStages(context, elementMarshaller, out string indexer, stagesToGeneratePerElement); if (elementStatements.Count != 0) { - StatementSyntax marshallingStatement = Block( - List(elementMarshaller.Generate(elementSetupSubContext) - .Concat(elementStatements))); - - if (elementMarshaller.NativeType is PointerTypeInfo nativeTypeInfo) - { - PointerNativeTypeAssignmentRewriter rewriter = new(elementSetupSubContext.GetIdentifiers(elementMarshaller.TypeInfo).native, (PointerTypeSyntax)nativeTypeInfo.Syntax); - marshallingStatement = (StatementSyntax)rewriter.Visit(marshallingStatement); - } + StatementSyntax marshallingStatement = Block(elementStatements); // Iterate through the elements of the native collection to marshal them - var forLoop = ForLoop(elementSetupSubContext.IndexerIdentifier, lengthExpression) + var forLoop = ForLoop(indexer, lengthExpression) .WithStatement(marshallingStatement); // If we're tracking LastIndexMarshalled, increment that each iteration as well. if (UsesLastIndexMarshalled(CollectionSource.TypeInfo, CollectionSource.CodeContext) && stagesToGeneratePerElement.Contains(StubIdentifierContext.Stage.Marshal)) diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/GenericFriendlyPointerIdentifierContext.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/GenericFriendlyPointerIdentifierContext.cs new file mode 100644 index 00000000000000..8400b767e57360 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/GenericFriendlyPointerIdentifierContext.cs @@ -0,0 +1,38 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Text; + +namespace Microsoft.Interop +{ + internal sealed record GenericFriendlyPointerIdentifierContext : StubIdentifierContext + { + private readonly StubIdentifierContext _innerContext; + private readonly TypePositionInfo _adaptedInfo; + private readonly string _nativeIdentifier; + + public GenericFriendlyPointerIdentifierContext(StubIdentifierContext inner, TypePositionInfo adaptedInfo, string baseIdentifier) + { + _innerContext = inner; + _adaptedInfo = adaptedInfo; + _nativeIdentifier = baseIdentifier + "_exactType"; + CurrentStage = inner.CurrentStage; + } + + public override (string managed, string native) GetIdentifiers(TypePositionInfo info) + { + if (info.PositionsEqual(_adaptedInfo)) + { + (string managed, _) = _innerContext.GetIdentifiers(info); + return (managed, _nativeIdentifier); + } + + return _innerContext.GetIdentifiers(info); + } + + public override string GetAdditionalIdentifier(TypePositionInfo info, string name) => _innerContext.GetAdditionalIdentifier(info, name); + } +} diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/PointerNativeTypeAssignmentRewriter.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/PointerNativeTypeAssignmentRewriter.cs deleted file mode 100644 index 5f4f5d43a4a3df..00000000000000 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/PointerNativeTypeAssignmentRewriter.cs +++ /dev/null @@ -1,75 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp; -using Microsoft.CodeAnalysis.CSharp.Syntax; -using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; - -namespace Microsoft.Interop -{ - /// - /// Rewrite assignment expressions to the native identifier to cast to IntPtr. - /// This handles the case where the native type of a non-blittable managed type is a pointer, - /// which are unsupported in generic type parameters. - /// - internal sealed class PointerNativeTypeAssignmentRewriter : CSharpSyntaxRewriter - { - private readonly string _nativeIdentifier; - private readonly PointerTypeSyntax _nativeType; - - public PointerNativeTypeAssignmentRewriter(string nativeIdentifier, PointerTypeSyntax nativeType) - { - _nativeIdentifier = nativeIdentifier; - _nativeType = nativeType; - } - - public override SyntaxNode? VisitVariableDeclarator(VariableDeclaratorSyntax node) - { - if (node.Initializer is null) - { - return base.VisitVariableDeclarator(node); - } - - if (node.Identifier.ToString() == _nativeIdentifier) - { - return node.WithInitializer( - EqualsValueClause( - CastExpression(TypeSyntaxes.System_IntPtr, node.Initializer.Value))); - } - if (node.Initializer.Value.ToString() == _nativeIdentifier) - { - return node.WithInitializer( - EqualsValueClause( - CastExpression(_nativeType, node.Initializer.Value))); - } - - return base.VisitVariableDeclarator(node); - } - - public override SyntaxNode VisitAssignmentExpression(AssignmentExpressionSyntax node) - { - if (node.Left.ToString() == _nativeIdentifier) - { - return node.WithRight( - CastExpression(TypeSyntaxes.System_IntPtr, node.Right)); - } - if (node.Right.ToString() == _nativeIdentifier) - { - return node.WithRight(CastExpression(_nativeType, node.Right)); - } - - return base.VisitAssignmentExpression(node); - } - - public override SyntaxNode? VisitArgument(ArgumentSyntax node) - { - if (node.Expression.ToString() == _nativeIdentifier) - { - return node.WithExpression( - CastExpression(_nativeType, node.Expression)); - } - return base.VisitArgument(node); - } - } -} diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatelessMarshallingStrategy.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatelessMarshallingStrategy.cs index ad4135078dc4f3..da13f490614238 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatelessMarshallingStrategy.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatelessMarshallingStrategy.cs @@ -7,6 +7,7 @@ using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; +using static Microsoft.Interop.SyntaxFactoryExtensions; namespace Microsoft.Interop { @@ -71,18 +72,26 @@ public IEnumerable GenerateMarshalStatements(StubIdentifierCont yield break; } - // = ; - var assignment = AssignmentExpression( - SyntaxKind.SimpleAssignmentExpression, - IdentifierName(nativeIdentifier), - convertToUnmanaged); - + ExpressionSyntax assignment; - if (unmanagedType is PointerTypeInfo pointer) + // For some of our exception marshallers, our marshaller returns nint for pointer types. + // As a result, we need to insert a cast here in case we're in that scenario (which we can't detect specifically). + if (unmanagedType is PointerTypeInfo ptrType) + { + // = (); + assignment = AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, + IdentifierName(nativeIdentifier), + CastExpression(ptrType.Syntax, convertToUnmanaged)); + } + else { - var rewriter = new PointerNativeTypeAssignmentRewriter(assignment.Right.ToString(), (PointerTypeSyntax)pointer.Syntax); - assignment = (AssignmentExpressionSyntax)rewriter.Visit(assignment); + // = ; + assignment = AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + IdentifierName(nativeIdentifier), + convertToUnmanaged); } + yield return ExpressionStatement(assignment); }