Skip to content

Commit d6da9cc

Browse files
jkoritzinskysirntar
authored andcommitted
Remove PointerNativeTypeAssignmentRewriter by introducing a separate local with the exact type and assigning to/from it (dotnet#107219)
1 parent d10088a commit d6da9cc

File tree

4 files changed

+125
-100
lines changed

4 files changed

+125
-100
lines changed

src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ElementsMarshalling.cs

Lines changed: 69 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -468,13 +468,13 @@ public override StatementSyntax GenerateUnmanagedToManagedByValueOutMarshalState
468468
private List<StatementSyntax> GenerateElementStages(
469469
StubIdentifierContext context,
470470
IBoundMarshallingGenerator elementMarshaller,
471-
out LinearCollectionElementIdentifierContext elementSetupSubContext,
471+
out string indexer,
472472
params StubIdentifierContext.Stage[] stagesToGeneratePerElement)
473473
{
474474
string managedSpanIdentifier = MarshallerHelpers.GetManagedSpanIdentifier(CollectionSource.TypeInfo, context);
475475
string nativeSpanIdentifier = MarshallerHelpers.GetNativeSpanIdentifier(CollectionSource.TypeInfo, context);
476476
StubCodeContext elementCodeContext = StubCodeContext.CreateElementMarshallingContext(CollectionSource.CodeContext);
477-
elementSetupSubContext = new LinearCollectionElementIdentifierContext(
477+
LinearCollectionElementIdentifierContext elementSetupSubContext = new(
478478
context,
479479
elementMarshaller.TypeInfo,
480480
managedSpanIdentifier,
@@ -485,13 +485,74 @@ private List<StatementSyntax> GenerateElementStages(
485485
CodeEmitOptions = context.CodeEmitOptions
486486
};
487487

488+
indexer = elementSetupSubContext.IndexerIdentifier;
489+
490+
StubIdentifierContext identifierContext = elementSetupSubContext;
491+
492+
if (elementMarshaller.NativeType is PointerTypeInfo)
493+
{
494+
identifierContext = new GenericFriendlyPointerIdentifierContext(elementSetupSubContext, elementMarshaller.TypeInfo, $"{nativeSpanIdentifier}__{indexer}")
495+
{
496+
CodeEmitOptions = elementSetupSubContext.CodeEmitOptions,
497+
};
498+
}
499+
488500
List<StatementSyntax> elementStatements = [];
489501
foreach (StubIdentifierContext.Stage stage in stagesToGeneratePerElement)
490502
{
491-
var elementSubContext = elementSetupSubContext with { CurrentStage = stage };
492-
elementStatements.AddRange(elementMarshaller.Generate(elementSubContext));
503+
var elementIdentifierContext = identifierContext with { CurrentStage = stage };
504+
elementStatements.AddRange(elementMarshaller.Generate(elementIdentifierContext));
505+
}
506+
507+
if (elementStatements.Count == 0)
508+
{
509+
return [];
510+
}
511+
512+
// Only add the setup stage if we generated code for other stages.
513+
elementStatements.InsertRange(0, elementMarshaller.Generate(identifierContext with { CurrentStage = StubIdentifierContext.Stage.Setup }));
514+
515+
if (identifierContext is not GenericFriendlyPointerIdentifierContext)
516+
{
517+
// If we didn't need to account for pointer types, we have the statements we need.
518+
return elementStatements;
493519
}
494-
return elementStatements;
520+
521+
// If we have the generic friendly pointer context, we need to declare the special identifier and assign to/from it.
522+
523+
// <native_type> <native_exactType> = (<native_type>)<native_collection>[i];
524+
StatementSyntax exactTypeDeclaration =
525+
LocalDeclarationStatement(
526+
VariableDeclaration(
527+
elementMarshaller.NativeType.Syntax,
528+
SingletonSeparatedList(
529+
VariableDeclarator(
530+
Identifier(identifierContext.GetIdentifiers(elementMarshaller.TypeInfo).native))
531+
.WithInitializer(
532+
EqualsValueClause(
533+
CastExpression(elementMarshaller.NativeType.Syntax,
534+
ParseExpression(elementSetupSubContext.GetIdentifiers(elementMarshaller.TypeInfo).native)))))));
535+
536+
if (stagesToGeneratePerElement.Any(stage => stage is StubIdentifierContext.Stage.Marshal or StubIdentifierContext.Stage.PinnedMarshal))
537+
{
538+
// <native_collection>[i] = (<generic_compatible_native_type>)<native_exactType>;
539+
StatementSyntax propagateResult = AssignmentStatement(
540+
ParseExpression(elementSetupSubContext.GetIdentifiers(elementMarshaller.TypeInfo).native),
541+
CastExpression(TypeSyntaxes.System_IntPtr,
542+
IdentifierName(identifierContext.GetIdentifiers(elementMarshaller.TypeInfo).native)));
543+
544+
return
545+
[
546+
exactTypeDeclaration,
547+
..elementStatements,
548+
propagateResult
549+
];
550+
}
551+
552+
return [
553+
exactTypeDeclaration,
554+
..elementStatements
555+
];
495556
}
496557

497558
private StatementSyntax GenerateContentsMarshallingStatement(
@@ -500,22 +561,14 @@ private StatementSyntax GenerateContentsMarshallingStatement(
500561
IBoundMarshallingGenerator elementMarshaller,
501562
params StubIdentifierContext.Stage[] stagesToGeneratePerElement)
502563
{
503-
var elementStatements = GenerateElementStages(context, elementMarshaller, out var elementSetupSubContext, stagesToGeneratePerElement);
564+
var elementStatements = GenerateElementStages(context, elementMarshaller, out string indexer, stagesToGeneratePerElement);
504565

505566
if (elementStatements.Count != 0)
506567
{
507-
StatementSyntax marshallingStatement = Block(
508-
List(elementMarshaller.Generate(elementSetupSubContext)
509-
.Concat(elementStatements)));
510-
511-
if (elementMarshaller.NativeType is PointerTypeInfo nativeTypeInfo)
512-
{
513-
PointerNativeTypeAssignmentRewriter rewriter = new(elementSetupSubContext.GetIdentifiers(elementMarshaller.TypeInfo).native, (PointerTypeSyntax)nativeTypeInfo.Syntax);
514-
marshallingStatement = (StatementSyntax)rewriter.Visit(marshallingStatement);
515-
}
568+
StatementSyntax marshallingStatement = Block(elementStatements);
516569

517570
// Iterate through the elements of the native collection to marshal them
518-
var forLoop = ForLoop(elementSetupSubContext.IndexerIdentifier, lengthExpression)
571+
var forLoop = ForLoop(indexer, lengthExpression)
519572
.WithStatement(marshallingStatement);
520573
// If we're tracking LastIndexMarshalled, increment that each iteration as well.
521574
if (UsesLastIndexMarshalled(CollectionSource.TypeInfo, CollectionSource.CodeContext) && stagesToGeneratePerElement.Contains(StubIdentifierContext.Stage.Marshal))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using System;
5+
using System.Collections.Generic;
6+
using System.Diagnostics;
7+
using System.Text;
8+
9+
namespace Microsoft.Interop
10+
{
11+
internal sealed record GenericFriendlyPointerIdentifierContext : StubIdentifierContext
12+
{
13+
private readonly StubIdentifierContext _innerContext;
14+
private readonly TypePositionInfo _adaptedInfo;
15+
private readonly string _nativeIdentifier;
16+
17+
public GenericFriendlyPointerIdentifierContext(StubIdentifierContext inner, TypePositionInfo adaptedInfo, string baseIdentifier)
18+
{
19+
_innerContext = inner;
20+
_adaptedInfo = adaptedInfo;
21+
_nativeIdentifier = baseIdentifier + "_exactType";
22+
CurrentStage = inner.CurrentStage;
23+
}
24+
25+
public override (string managed, string native) GetIdentifiers(TypePositionInfo info)
26+
{
27+
if (info.PositionsEqual(_adaptedInfo))
28+
{
29+
(string managed, _) = _innerContext.GetIdentifiers(info);
30+
return (managed, _nativeIdentifier);
31+
}
32+
33+
return _innerContext.GetIdentifiers(info);
34+
}
35+
36+
public override string GetAdditionalIdentifier(TypePositionInfo info, string name) => _innerContext.GetAdditionalIdentifier(info, name);
37+
}
38+
}

src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/PointerNativeTypeAssignmentRewriter.cs

Lines changed: 0 additions & 75 deletions
This file was deleted.

src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatelessMarshallingStrategy.cs

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
using Microsoft.CodeAnalysis.CSharp;
88
using Microsoft.CodeAnalysis.CSharp.Syntax;
99
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
10+
using static Microsoft.Interop.SyntaxFactoryExtensions;
1011

1112
namespace Microsoft.Interop
1213
{
@@ -71,18 +72,26 @@ public IEnumerable<StatementSyntax> GenerateMarshalStatements(StubIdentifierCont
7172
yield break;
7273
}
7374

74-
// <nativeIdentifier> = <convertToUnmanaged>;
75-
var assignment = AssignmentExpression(
76-
SyntaxKind.SimpleAssignmentExpression,
77-
IdentifierName(nativeIdentifier),
78-
convertToUnmanaged);
79-
75+
ExpressionSyntax assignment;
8076

81-
if (unmanagedType is PointerTypeInfo pointer)
77+
// For some of our exception marshallers, our marshaller returns nint for pointer types.
78+
// As a result, we need to insert a cast here in case we're in that scenario (which we can't detect specifically).
79+
if (unmanagedType is PointerTypeInfo ptrType)
80+
{
81+
// <nativeIdentifier> = (<nativeType>)<convertToUnmanaged>;
82+
assignment = AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
83+
IdentifierName(nativeIdentifier),
84+
CastExpression(ptrType.Syntax, convertToUnmanaged));
85+
}
86+
else
8287
{
83-
var rewriter = new PointerNativeTypeAssignmentRewriter(assignment.Right.ToString(), (PointerTypeSyntax)pointer.Syntax);
84-
assignment = (AssignmentExpressionSyntax)rewriter.Visit(assignment);
88+
// <nativeIdentifier> = <convertToUnmanaged>;
89+
assignment = AssignmentExpression(
90+
SyntaxKind.SimpleAssignmentExpression,
91+
IdentifierName(nativeIdentifier),
92+
convertToUnmanaged);
8593
}
94+
8695
yield return ExpressionStatement(assignment);
8796
}
8897

0 commit comments

Comments
 (0)